def create_model_and_criterion(args):
    """Create the model named by ``args.model`` together with its loss criterion.

    The model class is looked up in ``models.__dict__`` and instantiated with
    ``{'dataset': args.dataset}`` merged with the ``literal_eval``-parsed
    ``args.model_config`` string (when it is non-empty).  Model and criterion
    are moved to ``args.device`` and the module-level ``dtype``; when
    ``args.dtype`` contains ``'half'`` the batch-norm modules are cast back to
    float32.

    :param args: parsed command-line namespace (reads ``model``, ``dataset``,
        ``model_config``, ``label_smoothing``, ``device`` and ``dtype``).
    :return: tuple ``(model, criterion)``.
    """
    model_cls = models.__dict__[args.model]
    model_config = {'dataset': args.dataset}
    # Compare by value: ``is not ''`` depends on CPython string interning and
    # triggers a SyntaxWarning on Python 3.8+.
    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model_cls(**model_config)
    logging.info("created model with configuration: %s", model_config)

    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # define loss function (criterion)
    loss_params = {}
    if args.label_smoothing > 0:
        # NOTE(review): torch.nn.CrossEntropyLoss accepts ``label_smoothing``,
        # not ``smooth_eps``. Unless the model defines its own ``criterion``
        # (or ``nn`` is not torch.nn in this file), this kwarg raises
        # TypeError — confirm.
        loss_params['smooth_eps'] = args.label_smoothing
    criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)(**loss_params)
    # ``dtype`` is the module-level global (assigned by main/main_worker).
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)
    # Batch-norm should always be done in float
    if 'half' in args.dtype:
        FilterModules(model, module=is_bn).to(dtype=torch.float)

    return model, criterion
def main():
    """Command-line entry point: train a model, or evaluate a checkpoint.

    Parses the command line into the module-level ``args`` and assigns the
    module-level ``best_prec1`` and ``dtype`` globals.

    With ``--evaluate`` the given checkpoint is loaded, the validation split
    is run once and the function returns.  Otherwise the model is trained
    from ``args.start_epoch`` to ``args.epochs``, optionally resuming weights
    (and the results CSV, when a directory is given) from ``--resume``.
    After every epoch the non-worker process (rank 0, or the only process)
    saves a checkpoint and updates/plots the results log under ``save_path``.

    An invalid ``--evaluate`` path is reported with ``parser.error``
    (presumably argparse, which exits the process).
    """
    global args, best_prec1, dtype
    best_prec1 = 0
    args = parser.parse_args()
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    # Compare strings by value; ``is ''`` depends on interning (SyntaxWarning on 3.8+).
    if args.save == '':
        args.save = time_stamp
    save_path = os.path.join(args.results_dir, args.save)

    args.distributed = args.local_rank >= 0 or args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_init,
                                world_size=args.world_size, rank=args.local_rank)
        args.local_rank = dist.get_rank()
        args.world_size = dist.get_world_size()
        if args.dist_backend == 'mpi':
            # If using MPI, select all visible devices
            args.device_ids = list(range(torch.cuda.device_count()))
        else:
            args.device_ids = [args.local_rank]

    # Non-zero ranks of a distributed run neither create the output directory
    # nor write a real log, and they skip checkpointing in the epoch loop.
    is_worker_rank = args.distributed and args.local_rank > 0

    if not os.path.exists(save_path) and not is_worker_rank:
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=is_worker_rank)

    results_path = os.path.join(save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)
    logging.info("creating model %s", args.model)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model
    model_cls = models.__dict__[args.model]
    model_config = {'dataset': args.dataset}
    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model_cls(**model_config)
    logging.info("created model with configuration: %s", model_config)
    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # Optionally load a checkpoint to evaluate, or resume training from one.
    # map_location="cpu": the model is still on the CPU at this point (it is
    # moved to args.device below), so don't restore tensors onto the device
    # they were saved from.
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate, map_location="cpu")
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)",
                     args.evaluate, checkpoint['epoch'])
    elif args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(checkpoint_file, 'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file, map_location="cpu")
            if args.start_epoch < 0:  # not explicitly set
                # Checkpoints are saved with 'epoch': epoch + 1 (the count of
                # completed epochs), which is exactly the next epoch to run.
                args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            logging.info("loaded checkpoint '%s' (epoch %s)",
                         checkpoint_file, checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    # define loss function (criterion) and optimizer
    loss_params = {}
    if args.label_smoothing > 0:
        loss_params['smooth_eps'] = args.label_smoothing
    criterion = getattr(model, 'criterion', CrossEntropyLoss)(**loss_params)
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)
    # Batch-norm should always be done in float
    if 'half' in args.dtype:
        FilterModules(model, module=is_bn).to(dtype=torch.float)

    # optimizer configuration: the model's own regime, else a single-phase
    # regime built from the command line.
    optim_regime = getattr(model, 'regime', [{'epoch': 0,
                                              'optimizer': args.optimizer,
                                              'lr': args.lr,
                                              'momentum': args.momentum,
                                              'weight_decay': args.weight_decay}])
    if isinstance(optim_regime, OptimRegime):
        optimizer = optim_regime
    else:
        optimizer = OptimRegime(model, optim_regime,
                                use_float_copy='half' in args.dtype)

    trainer = Trainer(model, criterion, optimizer,
                      device_ids=args.device_ids, device=args.device, dtype=dtype,
                      distributed=args.distributed, local_rank=args.local_rank,
                      grad_clip=args.grad_clip, print_freq=args.print_freq,
                      adapt_grad_norm=args.adapt_grad_norm)

    # Evaluation Data loading code
    args.eval_batch_size = args.eval_batch_size if args.eval_batch_size > 0 else args.batch_size
    val_data = DataRegime(getattr(model, 'data_eval_regime', None),
                          defaults={'datasets_path': args.datasets_dir, 'name': args.dataset,
                                    'split': 'val', 'augment': False,
                                    'input_size': args.input_size,
                                    'batch_size': args.eval_batch_size,
                                    'shuffle': False, 'num_workers': args.workers,
                                    'pin_memory': True, 'drop_last': False})

    if args.evaluate:
        val_results = trainer.validate(val_data.get_loader())
        logging.info(val_results)
        return

    # Training Data loading code
    train_data = DataRegime(getattr(model, 'data_regime', None),
                            defaults={'datasets_path': args.datasets_dir, 'name': args.dataset,
                                      'split': 'train', 'augment': True,
                                      'input_size': args.input_size,
                                      'batch_size': args.batch_size, 'shuffle': True,
                                      'num_workers': args.workers, 'pin_memory': True,
                                      'drop_last': True, 'distributed': args.distributed,
                                      'duplicates': args.duplicates,
                                      'cutout': {'holes': 1, 'length': 16} if args.cutout else None})

    logging.info('optimization regime: %s', optim_regime)
    args.start_epoch = max(args.start_epoch, 0)
    trainer.training_steps = args.start_epoch * len(train_data)
    for epoch in range(args.start_epoch, args.epochs):
        trainer.epoch = epoch
        train_data.set_epoch(epoch)
        val_data.set_epoch(epoch)
        logging.info('\nStarting Epoch: {0}\n'.format(epoch + 1))

        # train for one epoch
        train_results = trainer.train(train_data.get_loader(),
                                      duplicates=train_data.get('duplicates'),
                                      chunk_batch=args.chunk_batch)

        # evaluate on validation set
        val_results = trainer.validate(val_data.get_loader())

        if is_worker_rank:
            continue

        # remember best prec@1 and save checkpoint
        is_best = val_results['prec1'] > best_prec1
        best_prec1 = max(val_results['prec1'], best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'model': args.model,
            'config': args.model_config,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1
        }, is_best, path=save_path)

        logging.info('\nResults - Epoch: {0}\n'
                     'Training Loss {train[loss]:.4f} \t'
                     'Training Prec@1 {train[prec1]:.3f} \t'
                     'Training Prec@5 {train[prec5]:.3f} \t'
                     'Validation Loss {val[loss]:.4f} \t'
                     'Validation Prec@1 {val[prec1]:.3f} \t'
                     'Validation Prec@5 {val[prec5]:.3f} \t\n'
                     .format(epoch + 1, train=train_results, val=val_results))

        values = dict(epoch=epoch + 1, steps=trainer.training_steps)
        values.update({'training ' + k: v for k, v in train_results.items()})
        values.update({'validation ' + k: v for k, v in val_results.items()})
        results.add(**values)

        results.plot(x='epoch', y=['training loss', 'validation loss'],
                     legend=['training', 'validation'],
                     title='Loss', ylabel='loss')
        results.plot(x='epoch', y=['training error1', 'validation error1'],
                     legend=['training', 'validation'],
                     title='Error@1', ylabel='error %')
        results.plot(x='epoch', y=['training error5', 'validation error5'],
                     legend=['training', 'validation'],
                     title='Error@5', ylabel='error %')
        if 'grad' in train_results:
            results.plot(x='epoch', y=['training grad'],
                         legend=['gradient L2 norm'],
                         title='Gradient Norm', ylabel='value')
        results.save()
def main_worker(args):
    """Evaluate a saved checkpoint, optionally after BN absorption/calibration.

    The checkpoint at ``args.evaluate`` is mandatory.  Its ``'model'`` and
    ``'config'`` entries (when present) override ``args.model`` and
    ``args.model_config``.  The weights come either from the checkpoint or,
    when ``args.pretrained`` is set, from ``model_urls[args.model]``; exactly
    one of the two is loaded.  With ``args.absorb_bn`` batch-norm layers are
    folded via ``search_absorb_bn``.  With ``args.calibrate_bn`` BN statistics
    are re-estimated for 200 steps on the training split before validation.

    Assigns the module-level ``best_prec1`` and ``dtype`` globals.  A missing
    checkpoint file is reported with ``parser.error`` (presumably argparse,
    which exits the process).

    :param args: parsed command-line namespace.
    :return: the value returned by ``trainer.validate`` on the validation split.
    """
    global best_prec1, dtype
    best_prec1 = 0
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    # Compare strings by value; ``is ''`` depends on interning (SyntaxWarning on 3.8+).
    if args.save == '':
        args.save = time_stamp
    save_path = os.path.join(args.results_dir, args.save)

    args.distributed = args.local_rank >= 0 or args.world_size > 1

    is_worker_rank = args.distributed and args.local_rank > 0
    if not os.path.exists(save_path) and not is_worker_rank:
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=is_worker_rank)

    results_path = os.path.join(save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    if not os.path.isfile(args.evaluate):
        parser.error('invalid checkpoint: {}'.format(args.evaluate))
    checkpoint = torch.load(args.evaluate, map_location="cpu")
    # Override configuration with checkpoint info
    args.model = checkpoint.get('model', args.model)
    args.model_config = checkpoint.get('config', args.model_config)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)
    logging.info("creating model %s", args.model)

    # create model
    model_cls = models.__dict__[args.model]
    model_config = {'dataset': args.dataset}
    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model_cls(**model_config)
    logging.info("created model with configuration: %s", model_config)
    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # Load weights from exactly one source, so that --pretrained weights are
    # not overwritten by the checkpoint's state dict afterwards.
    if args.pretrained:
        # progress=True is load_state_dict_from_url's default (show a download
        # progress bar); no ``progress`` variable exists in this function.
        state_dict = load_state_dict_from_url(model_urls[args.model], progress=True)
        model.load_state_dict(state_dict)
    else:
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)",
                     args.evaluate, checkpoint['epoch'])

    if args.absorb_bn:
        search_absorb_bn(model, remove_bn=not args.calibrate_bn, verbose=True)

    # define loss function (criterion)
    loss_params = {}
    if args.label_smoothing > 0:
        # NOTE(review): torch.nn.CrossEntropyLoss accepts ``label_smoothing``,
        # not ``smooth_eps`` — confirm the criterion used here accepts it.
        loss_params['smooth_eps'] = args.label_smoothing
    criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)(**loss_params)
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)
    # Batch-norm should always be done in float
    if 'half' in args.dtype:
        FilterModules(model, module=is_bn).to(dtype=torch.float)

    trainer = Trainer(model, criterion,
                      device_ids=args.device_ids, device=args.device, dtype=dtype,
                      mixup=args.mixup, print_freq=args.print_freq)

    # Evaluation Data loading code
    val_data = DataRegime(None, defaults={
        'datasets_path': args.datasets_dir, 'name': args.dataset, 'split': 'val',
        'augment': args.augment, 'input_size': args.input_size,
        'batch_size': args.batch_size, 'shuffle': False,
        'duplicates': args.duplicates, 'autoaugment': args.autoaugment,
        'cutout': {'holes': 1, 'length': 16} if args.cutout else None,
        'num_workers': args.workers, 'pin_memory': True, 'drop_last': False})

    if args.calibrate_bn:
        train_data = DataRegime(None, defaults={
            'datasets_path': args.datasets_dir, 'name': args.dataset, 'split': 'train',
            'augment': True, 'input_size': args.input_size,
            'batch_size': args.batch_size, 'shuffle': True,
            'num_workers': args.workers, 'pin_memory': True, 'drop_last': False})
        trainer.calibrate_bn(train_data.get_loader(), num_steps=200)

    results = trainer.validate(val_data.get_loader(), average_output=args.avg_out)
    logging.info(results)
    print(results)
    return results
def main_worker(args):
    """Train a ``resnet`` with configurable dropout.

    The model is always built as ``resnet(**model_config)``.  ``model_config``
    carries the dataset, ``dp_type``/``dp_percentage``/``dropout`` settings
    and the device, optionally extended by the ``literal_eval``-parsed
    ``args.model_config``.  Every process, including non-zero distributed
    ranks, creates ``save_path``, dumps the arguments to ``args.txt`` and
    writes a real ``log.txt``.  Only the non-worker process saves
    checkpoints and results.

    NOTE(review): no checkpoint is loaded in this worker.  ``--evaluate``
    therefore validates the freshly initialised model, and ``--resume`` only
    changes the logging mode passed to ``setup_logging``.

    Assigns the module-level ``best_prec1`` and ``dtype`` globals.

    :param args: parsed command-line namespace.
    """
    global best_prec1, dtype
    best_prec1 = 0
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    # Compare strings by value; ``is ''`` depends on interning (SyntaxWarning on 3.8+).
    if args.save == '':
        args.save = time_stamp
    save_path = path.join(args.results_dir, args.save)

    args.distributed = args.local_rank >= 0 or args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_init,
                                world_size=args.world_size, rank=args.local_rank)
        args.local_rank = dist.get_rank()
        args.world_size = dist.get_world_size()
        if args.dist_backend == 'mpi':
            # If using MPI, select all visible devices
            args.device_ids = list(range(torch.cuda.device_count()))
        else:
            args.device_ids = [args.local_rank]

    # Not restricted to rank 0: every process prepares the output directory.
    if not path.exists(save_path):
        makedirs(save_path)
    dump_args(args, path.join(save_path, 'args.txt'))

    setup_logging(path.join(save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=False)

    results_path = path.join(save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)
    logging.info("creating model %s", args.model)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # All parameters to the model should be passed via this dict.
    model_config = {
        'dataset': args.dataset,
        'dp_type': args.dropout_type,
        'dp_percentage': args.dropout_perc,
        'dropout': args.drop_rate,
        'device': args.device,
    }
    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    # Create the Resnet model. args.model is only used for logging and for
    # the checkpoint metadata.
    model = resnet(**model_config)
    logging.info("created model with configuration: %s", model_config)
    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # define loss function (criterion) and optimizer
    loss_params = {}
    if args.label_smoothing > 0:
        loss_params['smooth_eps'] = args.label_smoothing
    criterion = getattr(model, 'criterion', CrossEntropyLoss)(**loss_params)
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)
    # Batch-norm should always be done in float
    if 'half' in args.dtype:
        FilterModules(model, module=is_bn).to(dtype=torch.float)

    # optimizer configuration: the model's own regime, else a single-phase
    # regime built from the command line.
    optim_regime = getattr(model, 'regime', [{'epoch': 0,
                                              'optimizer': args.optimizer,
                                              'lr': args.lr,
                                              'momentum': args.momentum,
                                              'weight_decay': args.weight_decay}])
    if isinstance(optim_regime, OptimRegime):
        optimizer = optim_regime
    else:
        optimizer = OptimRegime(model, optim_regime,
                                use_float_copy='half' in args.dtype)

    trainer = Trainer(model, criterion, optimizer,
                      device_ids=args.device_ids, device=args.device, dtype=dtype,
                      distributed=args.distributed, local_rank=args.local_rank,
                      mixup=args.mixup, loss_scale=args.loss_scale,
                      grad_clip=args.grad_clip, print_freq=args.print_freq,
                      adapt_grad_norm=args.adapt_grad_norm)
    if args.tensorwatch:
        trainer.set_watcher(filename=path.abspath(path.join(save_path, 'tensorwatch.log')),
                            port=args.tensorwatch_port)

    # Evaluation Data loading code
    args.eval_batch_size = args.eval_batch_size if args.eval_batch_size > 0 else args.batch_size
    val_data = DataRegime(getattr(model, 'data_eval_regime', None),
                          defaults={'datasets_path': args.datasets_dir, 'name': args.dataset,
                                    'split': 'val', 'augment': False,
                                    'input_size': args.input_size,
                                    'batch_size': args.eval_batch_size,
                                    'shuffle': False, 'num_workers': args.workers,
                                    'pin_memory': True, 'drop_last': False})

    if args.evaluate:
        val_results = trainer.validate(val_data.get_loader())
        logging.info(val_results)
        return

    # Training Data loading code
    train_data_defaults = {
        'datasets_path': args.datasets_dir, 'name': args.dataset, 'split': 'train',
        'augment': True, 'input_size': args.input_size,
        'batch_size': args.batch_size, 'shuffle': True,
        'num_workers': args.workers, 'pin_memory': True, 'drop_last': True,
        'distributed': args.distributed, 'duplicates': args.duplicates,
        'autoaugment': args.autoaugment,
        'cutout': {'holes': 1, 'length': 16} if args.cutout else None,
    }

    if hasattr(model, 'sampled_data_regime'):
        # The model provides (probability, config-overrides) pairs; each
        # override set becomes its own DataRegime on top of the defaults.
        probs, regime_configs = zip(*model.sampled_data_regime)
        regimes = []
        for config in regime_configs:
            defaults = dict(train_data_defaults)
            defaults.update(config)
            regimes.append(DataRegime(None, defaults=defaults))
        train_data = SampledDataRegime(regimes, probs)
    else:
        train_data = DataRegime(getattr(model, 'data_regime', None),
                                defaults=train_data_defaults)

    logging.info('optimization regime: %s', optim_regime)
    args.start_epoch = max(args.start_epoch, 0)
    trainer.training_steps = args.start_epoch * len(train_data)
    for epoch in range(args.start_epoch, args.epochs):
        trainer.epoch = epoch
        train_data.set_epoch(epoch)
        val_data.set_epoch(epoch)
        logging.info('\nStarting Epoch: {0}\n'.format(epoch + 1))

        # train for one epoch
        train_results = trainer.train(train_data.get_loader(),
                                      chunk_batch=args.chunk_batch)

        # evaluate on validation set
        val_results = trainer.validate(val_data.get_loader())

        if args.distributed and args.local_rank > 0:
            continue

        # remember best prec@1 and save checkpoint
        is_best = val_results['prec1'] > best_prec1
        best_prec1 = max(val_results['prec1'], best_prec1)

        optim_state_dict = None if args.drop_optim_state else optimizer.state_dict()

        save_checkpoint({
            'epoch': epoch + 1,
            'model': args.model,
            'config': args.model_config,
            'state_dict': model.state_dict(),
            'optim_state_dict': optim_state_dict,
            'best_prec1': best_prec1
        }, is_best, path=save_path, save_all=False)

        logging.info('\nResults - Epoch: {0}\n'
                     'Training Loss {train[loss]:.4f} \t'
                     'Training Prec@1 {train[prec1]:.3f} \t'
                     'Training Prec@5 {train[prec5]:.3f} \t'
                     'Validation Loss {val[loss]:.4f} \t'
                     'Validation Prec@1 {val[prec1]:.3f} \t'
                     'Validation Prec@5 {val[prec5]:.3f} \t\n'
                     .format(epoch + 1, train=train_results, val=val_results))

        values = dict(epoch=epoch + 1, steps=trainer.training_steps)
        values.update({'training ' + k: v for k, v in train_results.items()})
        values.update({'validation ' + k: v for k, v in val_results.items()})
        results.add(**values)

        results.plot(x='epoch', y=['training loss', 'validation loss'],
                     legend=['training', 'validation'],
                     title='Loss', ylabel='loss')
        results.plot(x='epoch', y=['training error1', 'validation error1'],
                     legend=['training', 'validation'],
                     title='Error@1', ylabel='error %')
        results.plot(x='epoch', y=['training error5', 'validation error5'],
                     legend=['training', 'validation'],
                     title='Error@5', ylabel='error %')
        if 'grad' in train_results:
            results.plot(x='epoch', y=['training grad'],
                         legend=['gradient L2 norm'],
                         title='Gradient Norm', ylabel='value')
        results.save()
def main_worker(args, ml_logger):
    """Train a model from ``models``, with optional ZeroBN layer configuration.

    Output goes to ``args.save_path``.  Only the non-worker process (rank 0,
    or the only process) creates the directory, exports ``config.json``,
    writes a real log and saves checkpoints/results.  When ``args.resume``
    points to a checkpoint (or a directory holding ``model_best.pth.tar``),
    the model weights, ``best_prec1``, the start epoch (if not set
    explicitly) and the optimizer state are restored.

    When ``model_config`` contains a ``'zeroBN'`` key:
      * the model is first trained for half a loader pass (hot start);
      * every ``ZeroBN`` module is then configured from the ``args``
        sparsity/cosine-similarity options;
      * ``trainer.collectStat`` runs at the start of each epoch.

    Assigns the module-level ``best_prec1`` and ``dtype`` globals.

    :param args: parsed command-line namespace.
    :param ml_logger: metric logger.  It is passed to ``trainer.train`` and
        receives ``'Val Acc1'`` via ``log_metric`` after each validation.
    """
    global best_prec1, dtype
    best_prec1 = 0
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)

    args.distributed = args.local_rank >= 0 or args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_init,
                                world_size=args.world_size, rank=args.local_rank)
        args.local_rank = dist.get_rank()
        args.world_size = dist.get_world_size()
        if args.dist_backend == 'mpi':
            # If using MPI, select all visible devices
            args.device_ids = list(range(torch.cuda.device_count()))
        else:
            args.device_ids = [args.local_rank]

    is_worker_rank = args.distributed and args.local_rank > 0
    if not is_worker_rank:
        if not path.exists(args.save_path):
            makedirs(args.save_path)
        export_args_namespace(args, path.join(args.save_path, 'config.json'))

    # Compare strings by value; ``is not ''`` depends on interning (SyntaxWarning on 3.8+).
    setup_logging(path.join(args.save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=is_worker_rank)

    results_path = path.join(args.save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    logging.info("saving to %s", args.save_path)
    logging.debug("run arguments: %s", args)
    logging.info("creating model %s", args.model)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model
    model_cls = models.__dict__[args.model]
    model_config = {'dataset': args.dataset}
    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model_cls(**model_config)
    logging.info("created model with configuration: %s", model_config)
    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # Initialise up front: if --resume is given but no checkpoint file is
    # found, this must still be bound for the optimizer restore below.
    optim_state_dict = None
    if args.resume:
        checkpoint_file = args.resume
        if path.isdir(checkpoint_file):
            results.load(path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = path.join(checkpoint_file, 'model_best.pth.tar')
        if path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file, map_location="cpu")
            if args.start_epoch < 0:  # not explicitly set
                args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optim_state_dict = checkpoint.get('optim_state_dict', None)
            logging.info("loaded checkpoint '%s' (epoch %s)",
                         checkpoint_file, checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    # define loss function (criterion) and optimizer
    loss_params = {}
    if args.label_smoothing > 0:
        loss_params['smooth_eps'] = args.label_smoothing
    criterion = getattr(model, 'criterion', CrossEntropyLoss)(**loss_params)
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)
    # Batch-norm should always be done in float
    if 'half' in args.dtype:
        FilterModules(model, module=is_bn).to(dtype=torch.float)

    # optimizer configuration: the model's own regime, else a single-phase
    # regime built from the command line.
    optim_regime = getattr(model, 'regime', [{'epoch': 0,
                                              'optimizer': args.optimizer,
                                              'lr': args.lr,
                                              'momentum': args.momentum,
                                              'weight_decay': args.weight_decay}])
    if isinstance(optim_regime, OptimRegime):
        optimizer = optim_regime
    else:
        optimizer = OptimRegime(model, optim_regime,
                                use_float_copy='half' in args.dtype)

    if optim_state_dict is not None:
        optimizer.load_state_dict(optim_state_dict)

    trainer = Trainer(
        model, criterion, optimizer,
        device_ids=args.device_ids, device=args.device, dtype=dtype,
        print_freq=args.print_freq,
        distributed=args.distributed, local_rank=args.local_rank,
        mixup=args.mixup, cutmix=args.cutmix,
        loss_scale=args.loss_scale, grad_clip=args.grad_clip,
        adapt_grad_norm=args.adapt_grad_norm,
    )

    # Evaluation Data loading code
    args.eval_batch_size = args.eval_batch_size if args.eval_batch_size > 0 else args.batch_size
    val_data = DataRegime(getattr(model, 'data_eval_regime', None),
                          defaults={'datasets_path': args.datasets_dir, 'name': args.dataset,
                                    'split': 'val', 'augment': False,
                                    'input_size': args.input_size,
                                    'batch_size': args.eval_batch_size,
                                    'shuffle': False, 'num_workers': args.workers,
                                    'pin_memory': True, 'drop_last': False})

    # Training Data loading code
    train_data_defaults = {
        'datasets_path': args.datasets_dir, 'name': args.dataset, 'split': 'train',
        'augment': True, 'input_size': args.input_size,
        'batch_size': args.batch_size, 'shuffle': True,
        'num_workers': args.workers, 'pin_memory': True, 'drop_last': True,
        'distributed': args.distributed, 'duplicates': args.duplicates,
        'autoaugment': args.autoaugment,
        'cutout': {'holes': 1, 'length': 16} if args.cutout else None,
    }

    if hasattr(model, 'sampled_data_regime'):
        # The model provides (probability, config-overrides) pairs; each
        # override set becomes its own DataRegime on top of the defaults.
        probs, regime_configs = zip(*model.sampled_data_regime)
        regimes = []
        for config in regime_configs:
            defaults = dict(train_data_defaults)
            defaults.update(config)
            regimes.append(DataRegime(None, defaults=defaults))
        train_data = SampledDataRegime(regimes, probs)
    else:
        train_data = DataRegime(getattr(model, 'data_regime', None),
                                defaults=train_data_defaults)

    logging.info('optimization regime: %s', optim_regime)
    logging.info('data regime: %s', train_data)
    args.start_epoch = max(args.start_epoch, 0)
    trainer.training_steps = args.start_epoch * len(train_data)

    if 'zeroBN' in model_config:
        # Hot start: train for half a pass over the training loader before
        # the ZeroBN layers are configured.
        num_steps = int(len(train_data.get_loader()) * 0.5)
        trainer.train(train_data.get_loader(), chunk_batch=args.chunk_batch,
                      num_steps=num_steps)
        for m in model.modules():
            if not isinstance(m, ZeroBN):
                continue
            m.max_sparsity = args.max_sparsity
            m.max_cos_sim = args.max_cos_sim
            # Layer-name filters are checked in order; a later match wins.
            if args.preserve_cosine:
                if args.layers_cos_sim1 in m.fullName:
                    m.preserve_cosine = args.preserve_cosine
                    m.cos_sim = args.cos_sim1
                if args.layers_cos_sim2 in m.fullName:
                    m.preserve_cosine = args.preserve_cosine
                    m.cos_sim = args.cos_sim2
                if args.layers_cos_sim3 in m.fullName:
                    m.preserve_cosine = args.preserve_cosine
                    m.cos_sim = args.cos_sim3
            if args.min_cos_sim:
                if args.layers_min_cos_sim1 in m.fullName:
                    m.min_cos_sim = args.min_cos_sim
                    m.cos_sim_min = args.cos_sim_min1
                if args.layers_min_cos_sim2 in m.fullName:
                    m.min_cos_sim = args.min_cos_sim
                    m.cos_sim_min = args.cos_sim_min2

    for epoch in range(args.start_epoch, args.epochs):
        trainer.epoch = epoch
        train_data.set_epoch(epoch)
        val_data.set_epoch(epoch)
        logging.info('\nStarting Epoch: {0}\n'.format(epoch + 1))
        if 'zeroBN' in model_config:
            # NOTE(review): collectStat is invoked twice with identical
            # arguments; nesting was reconstructed from collapsed source —
            # confirm both calls are intended and belong under this check.
            trainer.collectStat(train_data.get_loader(), num_steps=1,
                                prunRatio=args.stochasticPrunning,
                                cos_sim=args.cos_sim, cos_sim_max=args.cos_sim_max)
            trainer.collectStat(train_data.get_loader(), num_steps=1,
                                prunRatio=args.stochasticPrunning,
                                cos_sim=args.cos_sim, cos_sim_max=args.cos_sim_max)

        # train for one epoch
        train_results = trainer.train(train_data.get_loader(), ml_logger,
                                      chunk_batch=args.chunk_batch)

        # evaluate on validation set
        val_results = trainer.validate(val_data.get_loader())
        ml_logger.log_metric('Val Acc1', val_results['prec1'], step='auto')

        if is_worker_rank:
            continue

        # remember best prec@1 and save checkpoint
        is_best = val_results['prec1'] > best_prec1
        best_prec1 = max(val_results['prec1'], best_prec1)

        optim_state_dict = None if args.drop_optim_state else optimizer.state_dict()

        save_checkpoint({
            'epoch': epoch + 1,
            'model': args.model,
            'config': args.model_config,
            'state_dict': model.state_dict(),
            'optim_state_dict': optim_state_dict,
            'best_prec1': best_prec1
        }, is_best, path=args.save_path, save_all=args.save_all)

        logging.info('\nResults - Epoch: {0}\n'
                     'Training Loss {train[loss]:.4f} \t'
                     'Training Prec@1 {train[prec1]:.3f} \t'
                     'Training Prec@5 {train[prec5]:.3f} \t'
                     'Validation Loss {val[loss]:.4f} \t'
                     'Validation Prec@1 {val[prec1]:.3f} \t'
                     'Validation Prec@5 {val[prec5]:.3f} \t\n'
                     .format(epoch + 1, train=train_results, val=val_results))

        values = dict(epoch=epoch + 1, steps=trainer.training_steps)
        values.update({'training ' + k: v for k, v in train_results.items()})
        values.update({'validation ' + k: v for k, v in val_results.items()})
        results.add(**values)

        results.plot(x='epoch', y=['training loss', 'validation loss'],
                     legend=['training', 'validation'],
                     title='Loss', ylabel='loss')
        results.plot(x='epoch', y=['training error1', 'validation error1'],
                     legend=['training', 'validation'],
                     title='Error@1', ylabel='error %')
        results.plot(x='epoch', y=['training error5', 'validation error5'],
                     legend=['training', 'validation'],
                     title='Error@5', ylabel='error %')
        if 'grad' in train_results:
            results.plot(x='epoch', y=['training grad'],
                         legend=['gradient L2 norm'],
                         title='Gradient Norm', ylabel='value')
        results.save()
def _strip_module_prefix(state_dict, prefix='module.1.'):
    """Return a copy of ``state_dict`` with ``prefix`` removed from its keys.

    Each key is cut after the first occurrence of ``prefix``. Keys that do not
    contain ``prefix`` are kept unchanged instead of raising ``IndexError``.
    """
    stripped = {}
    for key, value in state_dict.items():
        new_key = key.split(prefix, 1)[1] if prefix in key else key
        stripped[new_key] = value
    return stripped


def main_worker(args):
    """Run one pruning / BN-tuning / evaluation job configured by ``args``.

    Flow:
      * If ``absorb_bn`` / ``load_from_vision`` / ``pretrained`` is set (and not
        ``batch_norn_tuning``), build or load the model, save its state dict
        (``.absorb_bn`` or ``.with_bn``) and, on most paths, return early.
      * Otherwise build the model and optionally load ``args.evaluate`` /
        ``args.resume`` checkpoints.
      * Then run one of:
        - AdaPrune (``args.adaprune``): cache every Conv2d/Linear layer's
          input/output on one training batch, call ``optimize_layer`` per
          layer, and save the masks' model plus an MSE CSV.
        - BN tuning (``args.batch_norn_tuning``).
        - Plain validation.

    Side effects: sets the module globals ``best_prec1`` and ``dtype``; writes
    logs, checkpoints and CSV files.

    Returns:
        ``None`` on the early-return paths, otherwise ``(acc, loss)``. Both are
        initialised to -1 and never updated here, so that tuple is always
        ``(-1, -1)``.
    """
    global best_prec1, dtype
    acc = -1
    loss = -1
    best_prec1 = 0
    dtype = torch_dtypes.get(args.dtype)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    # Compare by value: ``is ''`` relies on CPython string interning.
    if args.save == '':
        args.save = time_stamp
    save_path = os.path.join(args.results_dir, args.save)

    args.distributed = args.local_rank >= 0 or args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_init,
                                world_size=args.world_size, rank=args.local_rank)
        args.local_rank = dist.get_rank()
        args.world_size = dist.get_world_size()
        if args.dist_backend == 'mpi':
            # If using MPI, select all visible devices
            args.device_ids = list(range(torch.cuda.device_count()))
        else:
            args.device_ids = [args.local_rank]

    # Only the non-distributed / rank-0 process creates the output directory.
    if not os.path.exists(save_path) and not (args.distributed and args.local_rank > 0):
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=args.distributed and args.local_rank > 0)

    results_path = os.path.join(save_path, 'results')
    results = ResultsLog(
        results_path, title='Training Results - %s' % args.save)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)
    logging.info("creating model %s", args.model)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model (``model`` is the model class/factory until instantiated)
    model = models.__dict__[args.model]
    # 'imagenet_calib*' datasets use the ImageNet model configuration.
    dataset_type = 'imagenet' if 'imagenet_calib' in args.dataset else args.dataset
    model_config = {'dataset': dataset_type}

    if args.model_config != '':
        if isinstance(args.model_config, dict):
            # A dict config never overrides keys already set (i.e. 'dataset').
            for k, v in args.model_config.items():
                if k not in model_config:
                    model_config[k] = v
        else:
            # A string config is parsed and overrides every key.
            model_config.update(literal_eval(args.model_config))

    if (args.absorb_bn or args.load_from_vision or args.pretrained) and not args.batch_norn_tuning:
        if args.load_from_vision:
            import torchvision
            # NOTE(review): eval() of a command-line value; it is meant to build
            # torchvision.models.<name>(pretrained=True) but will run any expression.
            exec_lfv_str = 'torchvision.models.' + args.load_from_vision + '(pretrained=True)'
            model = eval(exec_lfv_str)
        else:
            if not os.path.isfile(args.absorb_bn):
                parser.error('invalid checkpoint: {}'.format(args.absorb_bn))
            model = model(**model_config)
            checkpoint = torch.load(args.absorb_bn, map_location=lambda storage, loc: storage)
            checkpoint = checkpoint['state_dict'] if 'state_dict' in checkpoint.keys() else checkpoint
            model.load_state_dict(_strip_module_prefix(checkpoint), strict=False)
        if args.load_from_vision or ('batch_norm' in model_config and model_config['batch_norm']):
            logging.info('Creating absorb_bn state dict')
            search_absorbe_bn(model)
            filename_ab = args.absorb_bn + '.absorb_bn' if args.absorb_bn else save_path + '/' + args.model + '.absorb_bn'
            torch.save(model.state_dict(), filename_ab)
            if not args.load_from_vision:
                return
        else:
            filename_bn = save_path + '/' + args.model + '.with_bn'
            torch.save(model.state_dict(), filename_bn)
        if args.load_from_vision:
            return

    # NOTE(review): if the '.with_bn' path above falls through, ``model`` is
    # already an nn.Module instance here and calling it with config kwargs
    # looks like it will fail -- confirm the intended flow.
    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)

    num_parameters = sum(l.nelement() for l in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # optionally load a checkpoint to evaluate
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate, map_location="cpu")
        # Override configuration with checkpoint info
        args.model = checkpoint.get('model', args.model)
        if not model_config['batch_norm']:
            search_absorbe_fake_bn(model)
        # load checkpoint
        if 'state_dict' in checkpoint.keys():
            state_dict = checkpoint['state_dict']
            if any('module.1.' in key for key in state_dict.keys()):
                model.load_state_dict(_strip_module_prefix(state_dict), strict=False)
            else:
                model.load_state_dict(state_dict, strict=False)
            logging.info("loaded checkpoint '%s'", args.evaluate)
        else:
            model.load_state_dict(checkpoint, strict=False)
            logging.info("loaded checkpoint '%s'", args.evaluate)

    if args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(
                checkpoint_file, 'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file)
            if args.start_epoch < 0:  # not explicitly set
                args.start_epoch = checkpoint['epoch'] - 1 if 'epoch' in checkpoint.keys() else 0
            best_prec1 = checkpoint['best_prec1'] if 'best_prec1' in checkpoint.keys() else -1
            sd = checkpoint['state_dict'] if 'state_dict' in checkpoint.keys() else checkpoint
            model.load_state_dict(sd, strict=False)
            logging.info("loaded checkpoint '%s' (epoch %s)",
                         checkpoint_file, args.start_epoch)
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    # define loss function (criterion) and optimizer
    loss_params = {}
    if args.label_smoothing > 0:
        loss_params['smooth_eps'] = args.label_smoothing
    criterion = getattr(model, 'criterion', CrossEntropyLoss)(**loss_params)
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)

    # Batch-norm should always be done in float
    if 'half' in args.dtype:
        FilterModules(model, module=is_bn).to(dtype=torch.float)

    # optimizer configuration
    optim_regime = getattr(model, 'regime', [{'epoch': 0,
                                              'optimizer': args.optimizer,
                                              'lr': args.lr,
                                              'momentum': args.momentum,
                                              'weight_decay': args.weight_decay}])
    if args.fine_tune or args.prune:
        if not args.resume:
            args.start_epoch = 0
        if args.update_only_th:
            optim_regime = [
                {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-1},
                {'epoch': 10, 'lr': 1e-2},
                {'epoch': 15, 'lr': 1e-3}]
        else:
            optim_regime = [
                {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-4, 'momentum': 0.9},
                {'epoch': 2, 'lr': 1e-5, 'momentum': 0.9},
                {'epoch': 10, 'lr': 1e-6, 'momentum': 0.9}]
    optimizer = optim_regime if isinstance(optim_regime, OptimRegime) \
        else OptimRegime(model, optim_regime, use_float_copy='half' in args.dtype)

    # Training Data loading code
    train_data = DataRegime(getattr(model, 'data_regime', None),
                            defaults={'datasets_path': args.datasets_dir, 'name': args.dataset,
                                      'split': 'train', 'augment': False,
                                      'input_size': args.input_size, 'batch_size': args.batch_size,
                                      'shuffle': True, 'num_workers': args.workers,
                                      'pin_memory': True, 'drop_last': True,
                                      'distributed': args.distributed, 'duplicates': args.duplicates,
                                      'autoaugment': args.autoaugment,
                                      'cutout': {'holes': 1, 'length': 16} if args.cutout else None})

    prunner = None
    trainer = Trainer(model, prunner, criterion, optimizer,
                      device_ids=args.device_ids, device=args.device, dtype=dtype,
                      distributed=args.distributed, local_rank=args.local_rank,
                      mixup=args.mixup, loss_scale=args.loss_scale,
                      grad_clip=args.grad_clip, print_freq=args.print_freq,
                      adapt_grad_norm=args.adapt_grad_norm, epoch=args.start_epoch)

    # Evaluation Data loading code (``dataset_type`` was computed above)
    args.eval_batch_size = args.eval_batch_size if args.eval_batch_size > 0 else args.batch_size
    val_data = DataRegime(getattr(model, 'data_eval_regime', None),
                          defaults={'datasets_path': args.datasets_dir, 'name': dataset_type,
                                    'split': 'val', 'augment': False,
                                    'input_size': args.input_size, 'batch_size': args.eval_batch_size,
                                    'shuffle': True, 'num_workers': args.workers,
                                    'pin_memory': True, 'drop_last': False})

    cached_input_output = {}
    # module -> qualified name, so names follow module identity rather than
    # assuming the forward-hook call order equals named_modules() order.
    layer_names = {}
    if args.adaprune:
        generate_masks_from_model(model)

        def hook(module, input, output):
            if module not in cached_input_output:
                cached_input_output[module] = []
            # Meanwhile store data in the RAM.
            cached_input_output[module].append((input[0].detach().cpu(), output.detach().cpu()))
            print(module.__str__()[:70])

        handlers = []
        count = 0
        global_val = calc_global_prune_val(model, args.sparsity_level) if args.global_pruning else None
        for name, m in model.named_modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                m.quantize = False
                if count < 1000:
                    handlers.append(m.register_forward_hook(hook))
                    count += 1
                    layer_names[m] = name

        # Store input/output for all quantizable layers (one training batch).
        trainer.validate(train_data.get_loader(), num_steps=1)
        print("Input/outputs cached")

        for handler in handlers:
            handler.remove()

        num_layers = len(cached_input_output)
        mse_df = pd.DataFrame(index=np.arange(num_layers), columns=['name', 'shape', 'mse_before', 'mse_after'])
        print_freq = 100
        masks_dict = {}
        for i, layer in enumerate(cached_input_output):
            layer.name = layer_names[layer]
            print("\nOptimize {}:{} for shape of {}".format(i, layer.name, layer.weight.shape))
            # --keep-first-last: the first and the last cached layer get
            # sparsity_level=1 and topk=prune_bs (presumably "do not prune").
            # The last index is num_layers - 1.
            is_edge_layer = args.keep_first_last and (i == 0 or i == num_layers - 1)
            sparsity_level = 1 if is_edge_layer else args.sparsity_level
            prune_topk = args.prune_bs if is_edge_layer else args.prune_topk
            mse_before, mse_after, snr_before, snr_after, kurt_in, kurt_w, mask = \
                optimize_layer(layer, cached_input_output[layer], args.optimize_weights,
                               bs=args.prune_bs, topk=prune_topk, extract_topk=args.prune_extract_topk,
                               unstructured=args.unstructured, sparsity_level=sparsity_level,
                               global_val=global_val, conf_level=args.conf_level)
            masks_dict[layer.name] = mask
            print("\nMSE before optimization: {}".format(mse_before))
            print("MSE after optimization: {}".format(mse_after))
            mse_df.loc[i, 'name'] = layer.name
            mse_df.loc[i, 'shape'] = str(layer.weight.shape)
            mse_df.loc[i, 'mse_before'] = mse_before
            mse_df.loc[i, 'mse_after'] = mse_after
            mse_df.loc[i, 'snr_before'] = snr_before
            mse_df.loc[i, 'snr_after'] = snr_after
            mse_df.loc[i, 'kurt_in'] = kurt_in
            mse_df.loc[i, 'kurt_w'] = kurt_w

            if i > 0 and i % print_freq == 0:
                print('\n')
                val_results = trainer.validate(val_data.get_loader())
                logging.info(val_results)

        calc_masks_sparsity(masks_dict)
        mse_csv = args.evaluate + '.mse.csv'
        mse_df.to_csv(mse_csv)

        filename = args.evaluate + '.adaprune'
        torch.save(model.state_dict(), filename)

        # Drop the cached activations before further validation passes.
        cached_input_output = None
        val_results = trainer.validate(val_data.get_loader())
        logging.info(val_results)
        trainer.cal_bn_stats(train_data.get_loader())
        val_results = trainer.validate(val_data.get_loader())
        logging.info(val_results)
    elif args.batch_norn_tuning:
        for m in model.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                m.quantize = False

        # ``torchvision`` is a function-local name (imported in another
        # branch), so import it here too before eval() references it.
        import torchvision
        # NOTE(review): eval() of a command-line value.
        exec_lfv_str = 'torchvision.models.' + args.load_from_vision + '(pretrained=True)'
        model_orig = eval(exec_lfv_str)
        model_orig.to(args.device, dtype)
        search_copy_bn_params(model_orig)

        # Copy the per-conv 'gamma'/'beta' tensors from the reference model
        # into the matching (by name) convs of ``model`` as new parameters.
        layers_orig = {n: m for n, m in model_orig.named_modules() if isinstance(m, nn.Conv2d)}
        layers_q = {n: m for n, m in model.named_modules() if isinstance(m, nn.Conv2d)}
        for layer_name, conv_orig in layers_orig.items():
            conv_q = layers_q[layer_name]
            conv_q.register_parameter('gamma', nn.Parameter(conv_orig.gamma.clone()))
            conv_q.register_parameter('beta', nn.Parameter(conv_orig.beta.clone()))

        del model_orig

        search_add_bn(model)

        print("Run BN tuning")
        for tt in range(args.tuning_iter):
            print(tt)
            trainer.cal_bn_stats(train_data.get_loader())

        search_absorbe_tuning_bn(model)

        filename = args.evaluate + '.bn_tuning'
        print("Save model to: {}".format(filename))
        torch.save(model.state_dict(), filename)

        val_results = trainer.validate(val_data.get_loader())
        logging.info(val_results)

        if args.res_log is not None:
            if not os.path.exists(args.res_log):
                df = pd.DataFrame()
            else:
                df = pd.read_csv(args.res_log, index_col=0)

            ckp = ntpath.basename(args.evaluate)
            df.loc[ckp, 'acc_bn_tuning'] = val_results['prec1']
            df.loc[ckp, 'loss_bn_tuning'] = val_results['loss']
            df.to_csv(args.res_log)
    else:
        if model_config['measure']:
            results = trainer.validate(train_data.get_loader(), rec=args.rec)
        elif args.evaluate_init_configuration:
            results = trainer.validate(val_data.get_loader())
            if args.res_log is not None:
                if not os.path.exists(args.res_log):
                    df = pd.DataFrame()
                else:
                    df = pd.read_csv(args.res_log, index_col=0)

                ckp = ntpath.basename(args.evaluate)
                df.loc[ckp, 'acc_base'] = results['prec1']
                df.loc[ckp, 'loss_base'] = results['loss']
                df.to_csv(args.res_log)

        if args.evaluate_init_configuration:
            logging.info(results)
    return acc, loss
def _calc_covmat(x_, partitions=64):
    """Summarise the Gram matrix ``x_ @ x_.T`` chunk by chunk.

    ``x_`` is split along dim 0 into ``partitions`` chunks of
    ``x_.shape[0] // partitions`` rows. Trailing rows that do not fill a chunk
    are ignored. Every pair of chunks is multiplied, so the full Gram matrix
    is never materialised at once. Main-diagonal entries are collected
    separately from all other entries.

    Returns:
        ``(diag_var, non_diag_var, diag_small_ratio, non_diag_small_ratio)``:
        the variance of the diagonal and off-diagonal entries, and, after mean
        centring, the fraction of entries in each group that lie below minus
        one standard deviation.
    """
    chunk = x_.shape[0] // partitions
    non_diags = []
    diags = []
    for p1 in range(partitions):
        x = x_[p1 * chunk:(p1 + 1) * chunk]
        for p2 in range(partitions):
            y = x_[p2 * chunk:(p2 + 1) * chunk]
            X = torch.matmul(x, y.transpose(0, 1))
            if p1 == p2:
                mask = torch.eye(X.shape[0], dtype=torch.bool, device=X.device)
                non_diags.append(X[~mask].reshape(-1).cpu())
                diags.append(X[mask].reshape(-1).cpu())
            else:
                # Blocks off the block-diagonal contain no main-diagonal
                # entries: every element is an off-diagonal entry.
                non_diags.append(X.reshape(-1).cpu())
    diags = torch.cat(diags)
    non_diags = torch.cat(non_diags)
    diag_var = diags.var()
    non_diag_var = non_diags.var()
    diags = diags - diags.mean()
    non_diags = non_diags - non_diags.mean()
    diag_small_ratio = (diags < -diags.std()).to(dtype=torch.float).mean()
    non_diag_small_ratio = (non_diags < -non_diags.std()).to(dtype=torch.float).mean()
    return diag_var, non_diag_var, diag_small_ratio, non_diag_small_ratio


def main_worker(args):
    """Train (or, with ``args.evaluate``, evaluate) a model with auto-named output.

    The run name is produced by ``auto_name(args, model_config)`` and used as
    ``args.save``.

    Optional features:
      * ``args.dry``: build the model, print its size and return.
      * ``args.monitor > 0``: log to a ``SummaryWriter`` under ``runs/<name>``.
      * ``args.covmat``: register a forward hook on one top-level child of the
        model (chosen by name, or by index where negative counts from the end)
        that prints Gram-matrix statistics of its output.

    Side effects: sets the module globals ``best_prec1`` and ``dtype``, plus
    the covmat accumulators ``diag_var_mean``, ``non_diag_var_mean`` and
    ``var_count`` when ``args.covmat`` is used.
    """
    global best_prec1, dtype
    global diag_var_mean, non_diag_var_mean, var_count
    best_prec1 = 0
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)
    if args.evaluate:
        args.results_dir = '/tmp'
    model_config = {'dataset': args.dataset, 'batch': args.batch_size}

    # Compare by value: ``is not ''`` relies on CPython string interning.
    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    # Auto-generated run name
    fname = auto_name(args, model_config)
    args.save = fname
    monitor = args.monitor
    print(fname)

    save_path = path.join(args.results_dir, args.save)

    args.distributed = args.local_rank >= 0 or args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_init,
                                world_size=args.world_size, rank=args.local_rank)
        args.local_rank = dist.get_rank()
        args.world_size = dist.get_world_size()
        if args.dist_backend == 'mpi':
            # If using MPI, select all visible devices
            args.device_ids = list(range(torch.cuda.device_count()))
        else:
            args.device_ids = [args.local_rank]

    # Only the non-distributed / rank-0 process writes the output directory.
    if not (args.distributed and args.local_rank > 0):
        if not args.dry:
            if not path.exists(save_path):
                makedirs(save_path)
            export_args_namespace(args, path.join(save_path, 'config.json'))

    if monitor > 0 and not args.dry:
        events_path = "runs/%s" % fname
        my_file = Path(events_path)
        if my_file.is_file():
            os.remove(events_path)

        writer = SummaryWriter(log_dir=events_path, comment=str(args))
        model_config['writer'] = writer
        model_config['monitor'] = monitor
    else:
        monitor = 0
        writer = None

    if args.dry:
        model = models.__dict__[args.model]
        model = model(**model_config)
        print("created model with configuration: %s" % model_config)
        num_parameters = sum(l.nelement() for l in model.parameters())
        print("number of parameters: %d" % num_parameters)
        return

    setup_logging(path.join(save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=args.distributed and args.local_rank > 0)

    results_path = path.join(save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)
    logging.info("creating model %s", args.model)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model
    model = models.__dict__[args.model]
    model = model(**model_config)
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    logging.info("created model with configuration: %s", model_config)
    num_parameters = sum(l.nelement() for l in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # optionally resume from a checkpoint
    if args.evaluate:
        if not path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate, map_location="cpu")
        # Override configuration with checkpoint info
        args.model = checkpoint.get('model', args.model)
        args.model_config = checkpoint.get('config', args.model_config)
        # load checkpoint
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)",
                     args.evaluate, checkpoint['epoch'])

    # Default to "no optimizer state"; previously this stayed undefined
    # (NameError below) when --resume pointed at a missing checkpoint.
    optim_state_dict = None
    if args.resume:
        checkpoint_file = args.resume
        if path.isdir(checkpoint_file):
            results.load(path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = path.join(
                checkpoint_file, 'model_best.pth.tar')
        if path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file, map_location="cpu")
            if args.start_epoch < 0:  # not explicitly set
                args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optim_state_dict = checkpoint.get('optim_state_dict', None)
            logging.info("loaded checkpoint '%s' (epoch %s)",
                         checkpoint_file, checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    # define loss function (criterion) and optimizer
    loss_params = {}
    if args.label_smoothing > 0:
        loss_params['smooth_eps'] = args.label_smoothing
    criterion = getattr(model, 'criterion', CrossEntropyLoss)(**loss_params)
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)

    # Batch-norm should always be done in float
    if 'half' in args.dtype:
        FilterModules(model, module=is_bn).to(dtype=torch.float)

    # optimizer configuration
    optim_regime = getattr(model, 'regime', [{'epoch': 0,
                                              'optimizer': args.optimizer,
                                              'lr': args.lr,
                                              'momentum': args.momentum,
                                              'weight_decay': args.weight_decay}])

    optimizer = optim_regime if isinstance(optim_regime, OptimRegime) \
        else OptimRegime(model, optim_regime, use_float_copy='half' in args.dtype)

    if optim_state_dict is not None:
        optimizer.load_state_dict(optim_state_dict)

    trainer = Trainer(model, criterion, optimizer,
                      device_ids=args.device_ids, device=args.device, dtype=dtype, print_freq=args.print_freq,
                      distributed=args.distributed, local_rank=args.local_rank, mixup=args.mixup, cutmix=args.cutmix,
                      loss_scale=args.loss_scale, grad_clip=args.grad_clip, adapt_grad_norm=args.adapt_grad_norm,
                      writer=writer, monitor=monitor)
    if args.tensorwatch:
        trainer.set_watcher(filename=path.abspath(path.join(save_path, 'tensorwatch.log')),
                            port=args.tensorwatch_port)

    # Evaluation Data loading code
    args.eval_batch_size = args.eval_batch_size if args.eval_batch_size > 0 else args.batch_size
    val_data = DataRegime(getattr(model, 'data_eval_regime', None),
                          defaults={'datasets_path': args.datasets_dir, 'name': args.dataset, 'split': 'val',
                                    'augment': False, 'input_size': args.input_size,
                                    'batch_size': args.eval_batch_size, 'shuffle': False,
                                    'num_workers': args.workers, 'pin_memory': True, 'drop_last': False})

    if args.evaluate:
        results = trainer.validate(val_data.get_loader())
        logging.info(results)
        return

    # Training Data loading code
    train_data_defaults = {'datasets_path': args.datasets_dir, 'name': args.dataset, 'split': 'train',
                           'augment': True, 'input_size': args.input_size, 'batch_size': args.batch_size,
                           'shuffle': True, 'num_workers': args.workers, 'pin_memory': True,
                           'drop_last': True, 'distributed': args.distributed, 'duplicates': args.duplicates,
                           'autoaugment': args.autoaugment,
                           'cutout': {'holes': 1, 'length': 16} if args.cutout else None}

    if hasattr(model, 'sampled_data_regime'):
        # (probability, config-override) pairs -> one DataRegime per config.
        sampled_data_regime = model.sampled_data_regime
        probs, regime_configs = zip(*sampled_data_regime)
        regimes = []
        for config in regime_configs:
            defaults = {**train_data_defaults}
            defaults.update(config)
            regimes.append(DataRegime(None, defaults=defaults))
        train_data = SampledDataRegime(regimes, probs)
    else:
        train_data = DataRegime(
            getattr(model, 'data_regime', None), defaults=train_data_defaults)

    logging.info('optimization regime: %s', optim_regime)
    logging.info('data regime: %s', train_data)
    args.start_epoch = max(args.start_epoch, 0)
    trainer.training_steps = args.start_epoch * len(train_data)

    if args.covmat != "":
        # ``args.covmat`` names a top-level child, or is an integer index into
        # model.named_children() (negative values count from the end).
        try:
            int_covmat = int(args.covmat)
        except ValueError:
            int_covmat = None
        if int_covmat is not None and int_covmat < 0:
            int_covmat += len(list(model.named_children()))

        var_count = 0
        diag_var_mean = 0
        non_diag_var_mean = 0

        def report_covmat_hook(module, input, output):
            # Accumulates running sums in module globals and prints every 10th call.
            global diag_var_mean
            global non_diag_var_mean
            global var_count
            flatten_output = output.reshape([-1, 1]).detach()
            diag_var, non_diag_var, diag_small_ratio, non_diag_small_ratio = _calc_covmat(flatten_output)
            diag_var_mean = diag_var_mean + diag_var
            non_diag_var_mean = non_diag_var_mean + non_diag_var
            var_count = var_count + 1
            if var_count % 10 == 1:
                print("diag_var = %.02f (%.02f), ratio: %.02f , non_diag_var = %0.2f (%.02f), ratio: %.02f" %
                      (diag_var, diag_var_mean / var_count, diag_small_ratio,
                       non_diag_var, non_diag_var_mean / var_count, non_diag_small_ratio))

        for child_idx, (name, layer) in enumerate(model.named_children()):
            if int_covmat is None:
                condition = (name == args.covmat)
            else:
                condition = (child_idx == int_covmat)
            if condition:
                layer.register_forward_hook(report_covmat_hook)

    for epoch in range(args.start_epoch, args.epochs):
        trainer.epoch = epoch
        train_data.set_epoch(epoch)
        val_data.set_epoch(epoch)
        logging.info('\nStarting Epoch: {0}\n'.format(epoch + 1))

        # train for one epoch
        train_results = trainer.train(train_data.get_loader(),
                                      chunk_batch=args.chunk_batch)

        # evaluate on validation set
        val_results = trainer.validate(val_data.get_loader())

        # Only the non-distributed / rank-0 process saves and logs.
        if args.distributed and args.local_rank > 0:
            continue

        # remember best prec@1 and save checkpoint
        is_best = val_results['prec1'] > best_prec1
        best_prec1 = max(val_results['prec1'], best_prec1)

        if args.drop_optim_state:
            optim_state_dict = None
        else:
            optim_state_dict = optimizer.state_dict()

        save_checkpoint({
            'epoch': epoch + 1,
            'model': args.model,
            'config': args.model_config,
            'state_dict': model.state_dict(),
            'optim_state_dict': optim_state_dict,
            'best_prec1': best_prec1
        }, is_best, path=save_path, save_all=args.save_all)

        logging.info('\nResults - Epoch: {0}\n'
                     'Training Loss {train[loss]:.4f} \t'
                     'Training Prec@1 {train[prec1]:.3f} \t'
                     'Training Prec@5 {train[prec5]:.3f} \t'
                     'Validation Loss {val[loss]:.4f} \t'
                     'Validation Prec@1 {val[prec1]:.3f} \t'
                     'Validation Prec@5 {val[prec5]:.3f} \t\n'
                     .format(epoch + 1, train=train_results, val=val_results))

        if writer is not None:
            writer.add_scalar('Train/Loss', train_results['loss'], epoch)
            writer.add_scalar('Train/Prec@1', train_results['prec1'], epoch)
            writer.add_scalar('Train/Prec@5', train_results['prec5'], epoch)
            writer.add_scalar('Val/Loss', val_results['loss'], epoch)
            writer.add_scalar('Val/Prec@1', val_results['prec1'], epoch)
            writer.add_scalar('Val/Prec@5', val_results['prec5'], epoch)

        values = dict(epoch=epoch + 1, steps=trainer.training_steps)
        values.update({'training ' + k: v for k, v in train_results.items()})
        values.update({'validation ' + k: v for k, v in val_results.items()})
        results.add(**values)

        results.plot(x='epoch', y=['training loss', 'validation loss'],
                     legend=['training', 'validation'],
                     title='Loss', ylabel='loss')
        results.plot(x='epoch', y=['training error1', 'validation error1'],
                     legend=['training', 'validation'],
                     title='Error@1', ylabel='error %')
        results.plot(x='epoch', y=['training error5', 'validation error5'],
                     legend=['training', 'validation'],
                     title='Error@5', ylabel='error %')
        if 'grad' in train_results.keys():
            results.plot(x='epoch', y=['training grad'],
                         legend=['gradient L2 norm'],
                         title='Gradient Norm', ylabel='value')
        results.save()
    logging.info(f'\nBest Validation Accuracy (top1): {best_prec1}')
    if writer:
        writer.close()
def main_worker(args):
    """Train (or, with ``args.evaluate``, evaluate) a model under ``FP8TrainingScheduler``.

    Each epoch calls ``fp8_scheduler.schedule_before_epoch(epoch)`` and passes
    its ``scheduled_instructions`` to ``trainer.train``. When
    ``scheduled_instructions['collect_stat']`` is set, the per-module gradient
    mean/std meters are appended to the ``grad_stats`` log and fed back via
    ``fp8_scheduler.update_stats``.

    Side effects: sets the module globals ``best_prec1`` and ``dtype``; writes
    logs, results/grad-stat CSVs and checkpoints under the results directory.
    """
    global best_prec1, dtype
    best_prec1 = 0
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    # Compare by value: ``is ''`` relies on CPython string interning.
    if args.save == '':
        args.save = time_stamp
    save_path = path.join(args.results_dir, args.save)

    args.distributed = args.local_rank >= 0 or args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_init,
                                world_size=args.world_size,
                                rank=args.local_rank)
        args.local_rank = dist.get_rank()
        args.world_size = dist.get_world_size()
        if args.dist_backend == 'mpi':
            # If using MPI, select all visible devices
            args.device_ids = list(range(torch.cuda.device_count()))
        else:
            args.device_ids = [args.local_rank]

    # Only the non-distributed / rank-0 process writes the output directory.
    if not (args.distributed and args.local_rank > 0):
        if not path.exists(save_path):
            makedirs(save_path)
        export_args_namespace(args, path.join(save_path, 'config.json'))

    setup_logging(path.join(save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=args.distributed and args.local_rank > 0)

    results_path = path.join(save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)
    grad_stats_path = path.join(save_path, 'grad_stats')
    grad_stats = ResultsLog(grad_stats_path,
                            title='collect grad stats - %s' % args.save)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)
    logging.info("creating model %s", args.model)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model
    model = models.__dict__[args.model]
    model_config = {'dataset': args.dataset}

    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))
    if args.enable_scheduler:
        model_config['fp8_dynamic'] = True
    if args.smart_loss_scale_only:
        model_config['smart_loss_scale_only'] = True
    if args.smart_loss_scale_and_exp_bits:
        model_config['smart_loss_scale_and_exp_bits'] = True

    model = model(**model_config)
    quantize_modules_name = [
        n for n, m in model.named_modules() if isinstance(m, nn.Conv2d)
    ]
    # Stats file for the scheduler. It can be overridden with an optional
    # ``args.fp8_stats_path``; otherwise the original machine-specific path is
    # used (ResNet18 / cifar10). A ResNet18 / ImageNet alternative was:
    # /data/moran/ConvNet_lowp_0/convNet.pytorch/results/2020-05-19_01-27-57/results.csv
    stats_path = getattr(args, 'fp8_stats_path', None) or \
        "/data/moran/ConvNet_lowp_0/convNet.pytorch/results/2020-05-16_01-44-22/results.csv"
    fp8_scheduler = FP8TrainingScheduler(
        model,
        model_config,
        args,
        collect_stats_online=False,
        start_to_collect_stats_in_epoch=3,
        collect_stats_every_epochs=10,
        online_update=False,
        first_update_with_stats_from_epoch=4,
        start_online_update_in_epoch=3,
        update_every_epochs=1,
        update_loss_scale=True,
        update_exp_bit_width=args.smart_loss_scale_and_exp_bits,
        stats_path=stats_path,
        quantize_modules_name=quantize_modules_name,
        enable_scheduler=False)

    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    logging.info("created model with configuration: %s", model_config)
    num_parameters = sum(l.nelement() for l in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # optionally resume from a checkpoint
    if args.evaluate:
        if not path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate, map_location="cpu")
        # Override configuration with checkpoint info
        args.model = checkpoint.get('model', args.model)
        args.model_config = checkpoint.get('config', args.model_config)
        # load checkpoint
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)", args.evaluate,
                     checkpoint['epoch'])

    # Default to "no optimizer state"; previously this stayed undefined
    # (NameError below) when --resume pointed at a missing checkpoint.
    optim_state_dict = None
    if args.resume:
        checkpoint_file = args.resume
        if path.isdir(checkpoint_file):
            results.load(path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = path.join(checkpoint_file, 'model_best.pth.tar')
        if path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file, map_location="cpu")
            if args.start_epoch < 0:  # not explicitly set
                args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optim_state_dict = checkpoint.get('optim_state_dict', None)
            logging.info("loaded checkpoint '%s' (epoch %s)", checkpoint_file,
                         checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    # define loss function (criterion) and optimizer
    loss_params = {}
    if args.label_smoothing > 0:
        loss_params['smooth_eps'] = args.label_smoothing
    criterion = getattr(model, 'criterion', CrossEntropyLoss)(**loss_params)
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)

    # Batch-norm should always be done in float
    if 'half' in args.dtype:
        FilterModules(model, module=is_bn).to(dtype=torch.float)

    # optimizer configuration
    optim_regime = getattr(model, 'regime', [{
        'epoch': 0,
        'optimizer': args.optimizer,
        'lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay
    }])

    optimizer = optim_regime if isinstance(optim_regime, OptimRegime) \
        else OptimRegime(model, optim_regime, use_float_copy='half' in args.dtype)

    if optim_state_dict is not None:
        optimizer.load_state_dict(optim_state_dict)

    trainer = Trainer(model,
                      criterion,
                      optimizer,
                      device_ids=args.device_ids,
                      device=args.device,
                      dtype=dtype,
                      print_freq=args.print_freq,
                      distributed=args.distributed,
                      local_rank=args.local_rank,
                      mixup=args.mixup,
                      cutmix=args.cutmix,
                      loss_scale=args.loss_scale,
                      grad_clip=args.grad_clip,
                      adapt_grad_norm=args.adapt_grad_norm,
                      enable_input_grad_statistics=True,
                      exp_bits=args.exp_bits,
                      fp_bits=args.fp_bits)
    if args.tensorwatch:
        trainer.set_watcher(filename=path.abspath(
            path.join(save_path, 'tensorwatch.log')),
            port=args.tensorwatch_port)

    # Evaluation Data loading code
    args.eval_batch_size = args.eval_batch_size if args.eval_batch_size > 0 else args.batch_size
    val_data = DataRegime(getattr(model, 'data_eval_regime', None),
                          defaults={
                              'datasets_path': args.datasets_dir,
                              'name': args.dataset,
                              'split': 'val',
                              'augment': False,
                              'input_size': args.input_size,
                              'batch_size': args.eval_batch_size,
                              'shuffle': False,
                              'num_workers': args.workers,
                              'pin_memory': True,
                              'drop_last': False
                          })

    if args.evaluate:
        results = trainer.validate(val_data.get_loader())
        logging.info(results)
        return

    # Training Data loading code
    train_data_defaults = {
        'datasets_path': args.datasets_dir,
        'name': args.dataset,
        'split': 'train',
        'augment': True,
        'input_size': args.input_size,
        'batch_size': args.batch_size,
        'shuffle': True,
        'num_workers': args.workers,
        'pin_memory': True,
        'drop_last': True,
        'distributed': args.distributed,
        'duplicates': args.duplicates,
        'autoaugment': args.autoaugment,
        'cutout': {
            'holes': 1,
            'length': 16
        } if args.cutout else None
    }

    if hasattr(model, 'sampled_data_regime'):
        # (probability, config-override) pairs -> one DataRegime per config.
        sampled_data_regime = model.sampled_data_regime
        probs, regime_configs = zip(*sampled_data_regime)
        regimes = []
        for config in regime_configs:
            defaults = {**train_data_defaults}
            defaults.update(config)
            regimes.append(DataRegime(None, defaults=defaults))
        train_data = SampledDataRegime(regimes, probs)
    else:
        train_data = DataRegime(getattr(model, 'data_regime', None),
                                defaults=train_data_defaults)

    logging.info('optimization regime: %s', optim_regime)
    logging.info('data regime: %s', train_data)
    args.start_epoch = max(args.start_epoch, 0)
    trainer.training_steps = args.start_epoch * len(train_data)
    for epoch in range(args.start_epoch, args.epochs):
        trainer.epoch = epoch
        train_data.set_epoch(epoch)
        val_data.set_epoch(epoch)
        logging.info('\nStarting Epoch: {0}\n'.format(epoch + 1))
        fp8_scheduler.schedule_before_epoch(epoch)
        # train for one epoch
        train_results, meters_grad = trainer.train(
            train_data.get_loader(),
            chunk_batch=args.chunk_batch,
            scheduled_instructions=fp8_scheduler.scheduled_instructions)
        # evaluate on validation set

        if args.calibrate_bn:
            # Use a separate variable: reassigning ``train_data`` here would
            # make every later epoch train on this calibration regime.
            calib_data = DataRegime(None,
                                    defaults={
                                        'datasets_path': args.datasets_dir,
                                        'name': args.dataset,
                                        'split': 'train',
                                        'augment': True,
                                        'input_size': args.input_size,
                                        'batch_size': args.batch_size,
                                        'shuffle': True,
                                        'num_workers': args.workers,
                                        'pin_memory': True,
                                        'drop_last': False
                                    })
            trainer.calibrate_bn(calib_data.get_loader(), num_steps=200)

        val_results, _ = trainer.validate(val_data.get_loader())

        # Only the non-distributed / rank-0 process saves and logs.
        if args.distributed and args.local_rank > 0:
            continue

        # remember best prec@1 and save checkpoint
        is_best = val_results['prec1'] > best_prec1
        best_prec1 = max(val_results['prec1'], best_prec1)

        if args.drop_optim_state:
            optim_state_dict = None
        else:
            optim_state_dict = optimizer.state_dict()

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'config': args.model_config,
                'state_dict': model.state_dict(),
                'optim_state_dict': optim_state_dict,
                'best_prec1': best_prec1
            },
            is_best,
            path=save_path,
            save_all=args.save_all)

        logging.info('\nResults - Epoch: {0}\n'
                     'Training Loss {train[loss]:.4f} \t'
                     'Training Prec@1 {train[prec1]:.3f} \t'
                     'Training Prec@5 {train[prec5]:.3f} \t'
                     'Validation Loss {val[loss]:.4f} \t'
                     'Validation Prec@1 {val[prec1]:.3f} \t'
                     'Validation Prec@5 {val[prec5]:.3f} \t\n'.format(
                         epoch + 1, train=train_results, val=val_results))

        values = dict(epoch=epoch + 1, steps=trainer.training_steps)
        values.update({'training ' + k: v for k, v in train_results.items()})
        values.update({'validation ' + k: v for k, v in val_results.items()})
        values.update(
            {'grad mean ' + k: v['mean'].avg for k, v in meters_grad.items()})
        values.update(
            {'grad std ' + k: v['std'].avg for k, v in meters_grad.items()})

        results.add(**values)
        # stats were collected this epoch: log them and feed the scheduler
        if fp8_scheduler.scheduled_instructions['collect_stat']:
            grad_stats_values = dict(epoch=epoch + 1)
            grad_stats_values.update({
                'grad mean ' + k: v['mean'].avg
                for k, v in meters_grad.items()
            })
            grad_stats_values.update({
                'grad std ' + k: v['std'].avg
                for k, v in meters_grad.items()
            })
            grad_stats.add(**grad_stats_values)
            fp8_scheduler.update_stats(grad_stats)

        results.plot(x='epoch',
                     y=['training loss', 'validation loss'],
                     legend=['training', 'validation'],
                     title='Loss',
                     ylabel='loss')
        results.plot(x='epoch',
                     y=['training error1', 'validation error1'],
                     legend=['training', 'validation'],
                     title='Error@1',
                     ylabel='error %')
        results.plot(x='epoch',
                     y=['training error5', 'validation error5'],
                     legend=['training', 'validation'],
                     title='Error@5',
                     ylabel='error %')
        if 'grad' in train_results.keys():
            results.plot(x='epoch',
                         y=['training grad'],
                         legend=['gradient L2 norm'],
                         title='Gradient Norm',
                         ylabel='value')
        results.save()
        grad_stats.save()
def _load_torchvision_model(name):
    '''
    Build a pretrained torchvision model from its factory name.

    Resolves ``name`` (optionally dotted, e.g. ``'quantization.resnet18'``) as an
    attribute path under ``torchvision.models`` and calls it with ``pretrained=True``.
    Replaces ``eval('torchvision.models.' + name + '(pretrained=True)')``: the
    command-line value is no longer executed as code, and ``torchvision`` is imported
    here so the lookup does not depend on which branch of the caller imported it.
    :param name: model factory name under ``torchvision.models``
    :return: the model instance returned by the factory
    '''
    import torchvision
    factory = torchvision.models
    for attr in name.split('.'):
        factory = getattr(factory, attr)
    return factory(pretrained=True)


def _read_res_log(path):
    '''
    Load the cross-run results table, or start an empty one if it does not exist yet.
    :param path: CSV previously written with ``DataFrame.to_csv`` (first column is the index)
    :return: pandas DataFrame
    '''
    if not os.path.exists(path):
        return pd.DataFrame()
    return pd.read_csv(path, index_col=0)


def _layer_quant_keys(state_dict, layer, quant_keys):
    '''
    Return the keys of ``state_dict`` equal to ``layer`` + one of ``quant_keys``,
    in ``state_dict`` iteration order.
    '''
    wanted = {layer + suffix for suffix in quant_keys}
    return [key for key in state_dict if key in wanted]


def main_worker(args):
    '''
    Run one invocation of the quantization / calibration pipeline described by ``args``.

    Seeds the RNGs, prepares the output directory and logging, optionally initialises
    torch.distributed, builds the model from ``models.__dict__[args.model]``, optionally
    loads the ``args.evaluate`` / ``args.resume`` checkpoints, builds the criterion,
    optimizer regime, ``Trainer`` and train/val ``DataRegime`` objects, then runs exactly
    one mode (first match wins): ``adaquant``, ``per_layer``, ``mixed_builder``,
    ``batch_norn_tuning``, ``bias_tuning``, or the default measure / evaluate-init branch.

    Assigns the module globals ``best_prec1`` and ``dtype`` and mutates several ``args``
    attributes (``save``, ``results_dir``, ``distributed``, ``device_ids``,
    ``start_epoch``, ``names_sp_layers``, ``eval_batch_size``, ...).

    :return: ``(acc, loss)`` -- prec@1 and loss of the mixed-precision model when
             ``args.mixed_builder`` is set, otherwise ``(-1, -1)``. Returns ``None`` right
             after saving the absorbed / with-BN state dict when ``args.load_from_vision``
             or ``args.absorb_bn`` is set without ``args.evaluate_init_configuration``.
    '''
    global best_prec1, dtype
    acc = -1
    loss = -1
    best_prec1 = 0
    dtype = torch_dtypes.get(args.dtype)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    # '==' rather than 'is': identity comparison with a str literal is
    # implementation-defined (and a SyntaxWarning since Python 3.8).
    if args.save == '':
        args.save = time_stamp
    save_path = os.path.join(args.results_dir, args.save)

    args.distributed = args.local_rank >= 0 or args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_init,
                                world_size=args.world_size,
                                rank=args.local_rank)
        args.local_rank = dist.get_rank()
        args.world_size = dist.get_world_size()
        if args.dist_backend == 'mpi':
            # If using MPI, select all visible devices
            args.device_ids = list(range(torch.cuda.device_count()))
        else:
            args.device_ids = [args.local_rank]

    # Ranks > 0 of a distributed run neither create the output dir nor log for real.
    is_secondary_rank = args.distributed and args.local_rank > 0
    if not os.path.exists(save_path) and not is_secondary_rank:
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=is_secondary_rank)

    results_path = os.path.join(save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)
    logging.info("creating model %s", args.model)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model
    model = models.__dict__[args.model]
    # 'imagenet_calib' is mapped to 'imagenet' for the model config and the val split.
    dataset_type = 'imagenet' if args.dataset == 'imagenet_calib' else args.dataset
    model_config = {'dataset': dataset_type}

    if args.model_config != '':
        if isinstance(args.model_config, dict):
            # A dict config only fills keys that are not already set.
            for k, v in args.model_config.items():
                if k not in model_config:
                    model_config[k] = v
        else:
            model_config.update(literal_eval(args.model_config))

    if (args.absorb_bn or args.load_from_vision or args.pretrained) and not args.batch_norn_tuning:
        if args.load_from_vision:
            model = _load_torchvision_model(args.load_from_vision)
            if 'pytcv' in args.model:
                from pytorchcv.model_provider import get_model as ptcv_get_model
                model_pytcv = ptcv_get_model(args.load_from_vision, pretrained=True)
                model = convert_pytcv_model(model, model_pytcv)
        else:
            if not os.path.isfile(args.absorb_bn):
                parser.error('invalid checkpoint: {}'.format(args.absorb_bn))
            model = model(**model_config)
            checkpoint = torch.load(args.absorb_bn,
                                    map_location=lambda storage, loc: storage)
            if 'state_dict' in checkpoint:
                checkpoint = checkpoint['state_dict']
            model.load_state_dict(checkpoint, strict=False)
        if 'batch_norm' in model_config and not model_config['batch_norm']:
            logging.info('Creating absorb_bn state dict')
            search_absorbe_bn(model)
            if args.absorb_bn:
                filename_ab = args.absorb_bn + '.absorb_bn'
            else:
                filename_ab = save_path + '/' + args.model + '.absorb_bn'
            torch.save(model.state_dict(), filename_ab)
        else:
            filename_bn = save_path + '/' + args.model + '.with_bn'
            torch.save(model.state_dict(), filename_bn)
        if (args.load_from_vision or args.absorb_bn) and not args.evaluate_init_configuration:
            return
        # NOTE(review): if execution falls through here, ``model`` is already an
        # instance, so the factory call below invokes it instead of constructing a
        # new model -- confirm this path is intended.

    if 'inception' in args.model:
        model = model(init_weights=False, **model_config)
    else:
        model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)
    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # optionally resume from a checkpoint
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate, map_location="cpu")
        # Override configuration with checkpoint info
        args.model = checkpoint.get('model', args.model)
        args.model_config = checkpoint.get('config', args.model_config)
        # A missing 'batch_norm' key means "model has BN", as in the check above.
        if not model_config.get('batch_norm', True):
            search_absorbe_fake_bn(model)
        # load checkpoint
        if 'state_dict' in checkpoint:
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model.load_state_dict(checkpoint, strict=False)
        logging.info("loaded checkpoint '%s'", args.evaluate)

    if args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(checkpoint_file, 'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file)
            if args.start_epoch < 0:  # not explicitly set
                args.start_epoch = checkpoint['epoch'] - 1 if 'epoch' in checkpoint else 0
            best_prec1 = checkpoint['best_prec1'] if 'best_prec1' in checkpoint else -1
            sd = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
            model.load_state_dict(sd, strict=False)
            logging.info("loaded checkpoint '%s' (epoch %s)",
                         checkpoint_file, args.start_epoch)
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    # define loss function (criterion) and optimizer
    loss_params = {}
    if args.label_smoothing > 0:
        loss_params['smooth_eps'] = args.label_smoothing
    criterion = getattr(model, 'criterion', CrossEntropyLoss)(**loss_params)
    if args.kld_loss:
        criterion = nn.KLDivLoss(reduction='mean')
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)

    # Batch-norm should always be done in float
    if 'half' in args.dtype:
        FilterModules(model, module=is_bn).to(dtype=torch.float)

    # optimizer configuration
    optim_regime = getattr(model, 'regime', [{'epoch': 0,
                                              'optimizer': args.optimizer,
                                              'lr': args.lr,
                                              'momentum': args.momentum,
                                              'weight_decay': args.weight_decay}])
    if args.fine_tune or args.prune:
        if not args.resume:
            args.start_epoch = 0
        if args.update_only_th:
            optim_regime = [
                {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-1},
                {'epoch': 10, 'lr': 1e-2},
                {'epoch': 15, 'lr': 1e-3},
            ]
        else:
            optim_regime = [
                {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-4, 'momentum': 0.9},
                {'epoch': 2, 'lr': 1e-5, 'momentum': 0.9},
                {'epoch': 10, 'lr': 1e-6, 'momentum': 0.9},
            ]
    if isinstance(optim_regime, OptimRegime):
        optimizer = optim_regime
    else:
        optimizer = OptimRegime(model, optim_regime,
                                use_float_copy='half' in args.dtype)

    # Training Data loading code
    train_data = DataRegime(getattr(model, 'data_regime', None),
                            defaults={'datasets_path': args.datasets_dir,
                                      'name': args.dataset,
                                      'split': 'train',
                                      'augment': False,
                                      'input_size': args.input_size,
                                      'batch_size': args.batch_size,
                                      'shuffle': not args.seq_adaquant,
                                      'num_workers': args.workers,
                                      'pin_memory': True,
                                      'drop_last': True,
                                      'distributed': args.distributed,
                                      'duplicates': args.duplicates,
                                      'autoaugment': args.autoaugment,
                                      'cutout': {'holes': 1, 'length': 16} if args.cutout else None,
                                      'inception_prep': 'inception' in args.model})

    if args.names_sp_layers is None and args.layers_precision_dict is None:
        # Default: every conv / downsample.0 / fc weight; key[:-7] drops '.weight'.
        args.names_sp_layers = [
            key[:-7] for key in model.state_dict().keys()
            if 'weight' in key and 'running' not in key and
            ('conv' in key or 'downsample.0' in key or 'fc' in key)]
        if args.keep_first_last:
            args.names_sp_layers = [
                name for name in args.names_sp_layers
                if name not in ('conv1', 'fc', 'Conv2d_1a_3x3.conv')]
        if args.ignore_downsample:
            args.names_sp_layers = [
                k for k in args.names_sp_layers if 'downsample' not in k]
        if args.num_sp_layers == 0 and not args.keep_first_last:
            args.names_sp_layers = []

    if args.layers_precision_dict is not None:
        print(args.layers_precision_dict)

    prunner = None
    trainer = Trainer(model, prunner, criterion, optimizer,
                      device_ids=args.device_ids, device=args.device, dtype=dtype,
                      distributed=args.distributed, local_rank=args.local_rank,
                      mixup=args.mixup, loss_scale=args.loss_scale,
                      grad_clip=args.grad_clip, print_freq=args.print_freq,
                      adapt_grad_norm=args.adapt_grad_norm, epoch=args.start_epoch,
                      update_only_th=args.update_only_th,
                      optimize_rounding=args.optimize_rounding)

    # Evaluation Data loading code
    if args.eval_batch_size <= 0:
        args.eval_batch_size = args.batch_size
    val_data = DataRegime(getattr(model, 'data_eval_regime', None),
                          defaults={'datasets_path': args.datasets_dir,
                                    'name': dataset_type,
                                    'split': 'val',
                                    'augment': False,
                                    'input_size': args.input_size,
                                    'batch_size': args.eval_batch_size,
                                    'shuffle': True,
                                    'num_workers': args.workers,
                                    'pin_memory': True,
                                    'drop_last': False})

    if args.evaluate or args.resume:
        from utils.layer_sensativity import search_replace_layer, search_replace_layer_from_dict
        if args.layers_precision_dict is not None:
            model = search_replace_layer_from_dict(
                model, ast.literal_eval(args.layers_precision_dict))
        else:
            model = search_replace_layer(model, args.names_sp_layers,
                                         num_bits_activation=args.nbits_act,
                                         num_bits_weight=args.nbits_weight)

    cached_input_output = {}
    # Parameter / quantizer-buffer suffixes copied when a layer is switched between
    # precisions. Every entry needs its own comma: adjacent string literals are
    # silently concatenated, which previously merged the input1-range and
    # input2-zero-point keys into one key that never matched.
    quant_keys = [
        '.weight', '.bias', '.equ_scale',
        '.quantize_input.running_zero_point', '.quantize_input.running_range',
        '.quantize_weight.running_zero_point', '.quantize_weight.running_range',
        '.quantize_input1.running_zero_point', '.quantize_input1.running_range',
        '.quantize_input2.running_zero_point', '.quantize_input2.running_range',
    ]

    if args.adaquant:
        cached_qinput = {}

        def Qhook(name, module, inputs, output):
            # Collect the (quantized-path) input of ``module``, kept on the CPU.
            if module not in cached_qinput:
                cached_qinput[module] = []
            cached_qinput[module].append(inputs[0].detach().cpu())

        def hook(name, module, inputs, output):
            # Collect (input, output) pairs of ``module``, kept on the CPU.
            if module not in cached_input_output:
                cached_input_output[module] = []
            cached_input_output[module].append(
                (inputs[0].detach().cpu(), output.detach().cpu()))

        from models.modules.quantize import QConv2d, QLinear
        handlers = []
        count = 0
        for name, m in model.named_modules():
            if isinstance(m, (QConv2d, QLinear)):
                # Run the full-precision path while caching reference activations.
                m.quantize = False
                if count < 1000:
                    handlers.append(m.register_forward_hook(partial(hook, name)))
                    count += 1

        # Store input/output for all quantizable layers
        trainer.validate(train_data.get_loader())
        print("Input/outputs cached")

        for handler in handlers:
            handler.remove()

        for m in model.modules():
            if isinstance(m, (QConv2d, QLinear)):
                m.quantize = True

        mse_df = pd.DataFrame(index=np.arange(len(cached_input_output)),
                              columns=['name', 'bit', 'shape', 'mse_before', 'mse_after'])
        for i, layer in enumerate(cached_input_output):
            if i > 0 and args.seq_adaquant:
                # Sequential mode: replace the cached input of the first batch with the
                # input produced by the already-optimized (quantized) preceding layers.
                count = 0
                cached_qinput = {}
                for name, m in model.named_modules():
                    if layer.name == name and count < 1000:
                        handler = m.register_forward_hook(partial(Qhook, name))
                        count += 1
                trainer.validate(train_data.get_loader())
                print("cashed quant Input%s" % layer.name)
                cached_input_output[layer][0] = (
                    cached_qinput[layer][0], cached_input_output[layer][0][1])
                handler.remove()
            print("\nOptimize {}:{} for {} bit of shape {}".format(
                i, layer.name, layer.num_bits, layer.weight.shape))
            mse_before, mse_after, snr_before, snr_after, kurt_in, kurt_w = \
                optimize_layer(layer, cached_input_output[layer], args.optimize_weights,
                               batch_size=args.batch_size, model_name=args.model)
            print("\nMSE before optimization: {}".format(mse_before))
            print("MSE after optimization: {}".format(mse_after))
            mse_df.loc[i, 'name'] = layer.name
            mse_df.loc[i, 'bit'] = layer.num_bits
            mse_df.loc[i, 'shape'] = str(layer.weight.shape)
            mse_df.loc[i, 'mse_before'] = mse_before
            mse_df.loc[i, 'mse_after'] = mse_after
            mse_df.loc[i, 'snr_before'] = snr_before
            mse_df.loc[i, 'snr_after'] = snr_after
            mse_df.loc[i, 'kurt_in'] = kurt_in
            mse_df.loc[i, 'kurt_w'] = kurt_w

        mse_csv = args.evaluate + '.mse.csv'
        mse_df.to_csv(mse_csv)

        filename = args.evaluate + '.adaquant'
        torch.save(model.state_dict(), filename)

        # Drop references to the cached activations before validating.
        train_data = None
        cached_input_output = None
        val_results = trainer.validate(val_data.get_loader())
        logging.info(val_results)

        if args.res_log is not None:
            df = _read_res_log(args.res_log)
            ckp = ntpath.basename(args.evaluate)
            if args.cmp is not None:
                ckp += '_{}'.format(args.cmp)
            adaquant_type = 'adaquant_seq' if args.seq_adaquant else 'adaquant_parallel'
            df.loc[ckp, 'acc_' + adaquant_type] = val_results['prec1']
            df.to_csv(args.res_log)

    elif args.per_layer:
        # Baseline: everything at 8 bit.
        calib_all_8_results = trainer.validate(train_data.get_loader())
        print('########## All 8bit results ###########', calib_all_8_results)
        int8_opt_model_state_dict = torch.load(args.int8_opt_model_path)
        int4_opt_model_state_dict = torch.load(args.int4_opt_model_path)

        per_layer_results = {}
        args.names_sp_layers = [
            key[:-7] for key in model.state_dict().keys()
            if 'weight' in key and 'running' not in key and 'quantize' not in key and
            ('conv' in key or 'downsample.0' in key or 'fc' in key)]
        for layer_idx, layer in enumerate(args.names_sp_layers):
            # Switch just this layer to the low-precision weights, measure, switch back.
            model.load_state_dict(int8_opt_model_state_dict, strict=False)
            model = search_replace_layer(model, [layer],
                                         num_bits_activation=args.nbits_act,
                                         num_bits_weight=args.nbits_weight)
            for key in _layer_quant_keys(int8_opt_model_state_dict, layer, quant_keys):
                model.state_dict()[key].copy_(int4_opt_model_state_dict[key])
            calib_results = trainer.validate(train_data.get_loader())
            model = search_replace_layer(model, [layer],
                                         num_bits_activation=8, num_bits_weight=8)
            print('finished %d out of %d' % (layer_idx, len(args.names_sp_layers)))
            logging.info(layer)
            logging.info(calib_results)
            per_layer_results[layer] = {
                'base precision': 8,
                'replaced precision': args.nbits_act,
                'replaced layer': layer,
                'accuracy': calib_results['prec1'],
                'loss': calib_results['loss'],
                'Parameters Size [Elements]': model.state_dict()[layer + '.weight'].numel(),
                'MACs': '-',
            }

        per_layer_path = (args.evaluate + '.per_layer_accuracy.A' + str(args.nbits_act) +
                          '.W' + str(args.nbits_weight))
        torch.save(per_layer_results, per_layer_path)

        all_8_dict = {
            'base precision': 8,
            'replaced precision': args.nbits_act,
            'replaced layer': '-',
            'accuracy': calib_all_8_results['prec1'],
            'loss': calib_all_8_results['loss'],
            'Parameters Size [Elements]': '-',
            'MACs': '-',
        }
        columns = list(all_8_dict)
        with open(per_layer_path + '.csv', "w") as f:
            f.write(",".join(columns) + "\n")
            f.write(",".join(str(v) for v in all_8_dict.values()) + "\n")
            for row in per_layer_results.values():
                f.write(",".join(str(v) for v in row.values()) + "\n")

    elif args.mixed_builder:
        if isinstance(args.names_sp_layers, list):
            print('loading int8 model" ', args.int8_opt_model_path)
            int8_opt_model_state_dict = torch.load(args.int8_opt_model_path)
            print('loading int4 model" ', args.int4_opt_model_path)
            int4_opt_model_state_dict = torch.load(args.int4_opt_model_path)

            model.load_state_dict(int8_opt_model_state_dict, strict=False)
            model = search_replace_layer(model, args.names_sp_layers,
                                         num_bits_activation=args.nbits_act,
                                         num_bits_weight=args.nbits_weight)
            for layer in args.names_sp_layers:
                for key in _layer_quant_keys(int8_opt_model_state_dict, layer, quant_keys):
                    model.state_dict()[key].copy_(int4_opt_model_state_dict[key])
                print('switched layer %s to 4 bit' % (layer))
        elif isinstance(args.names_sp_layers, dict):
            quant_models = {}
            base_precision = args.precisions[0]
            for model_path, prec in zip(args.opt_model_paths, args.precisions):
                print('For precision={}, loading {}'.format(prec, model_path))
                quant_models[prec] = torch.load(model_path)
            model.load_state_dict(quant_models[base_precision], strict=False)
            for layer_name, nbits_list in args.names_sp_layers.items():
                # NOTE(review): nbits_list[0] is used for both activations and weights.
                model = search_replace_layer(model, [layer_name],
                                             num_bits_activation=nbits_list[0],
                                             num_bits_weight=nbits_list[0])
                layer_keys = _layer_quant_keys(quant_models[base_precision],
                                               layer_name, quant_keys)
                for key in layer_keys:
                    model.state_dict()[key].copy_(quant_models[nbits_list[0]][key])
                print('switched layer {} to {} bit'.format(layer_name, nbits_list[0]))
        if os.environ.get('DEBUG') == 'True':
            from utils.layer_sensativity import check_quantized_model
            fp_names = check_quantized_model(trainer.model)
            if len(fp_names) > 0:
                logging.info('Found FP32 layers in the model:')
                logging.info(fp_names)
        if args.eval_on_train:
            mixedIP_results = trainer.validate(train_data.get_loader())
        else:
            mixedIP_results = trainer.validate(val_data.get_loader())
        torch.save({'state_dict': model.state_dict(),
                    'config-ip': args.names_sp_layers},
                   args.evaluate + '.mixed-ip-results.' + args.suffix)
        logging.info(mixedIP_results)
        acc = mixedIP_results['prec1']
        loss = mixedIP_results['loss']

    elif args.batch_norn_tuning:
        from utils.layer_sensativity import search_replace_layer, search_replace_layer_from_dict
        from models.modules.quantize import QConv2d
        if args.layers_precision_dict is not None:
            model = search_replace_layer_from_dict(
                model, literal_eval(args.layers_precision_dict))
        else:
            model = search_replace_layer(model, args.names_sp_layers,
                                         num_bits_activation=args.nbits_act,
                                         num_bits_weight=args.nbits_weight)

        model_orig = _load_torchvision_model(args.load_from_vision)
        model_orig.to(args.device, dtype)
        search_copy_bn_params(model_orig)

        layers_orig = {n: m for n, m in model_orig.named_modules()
                       if isinstance(m, nn.Conv2d)}
        layers_q = {n: m for n, m in model.named_modules()
                    if isinstance(m, QConv2d)}
        # Give each quantized conv trainable copies of the original conv's gamma/beta.
        for layer_name, conv_orig in layers_orig.items():
            conv_q = layers_q[layer_name]
            conv_q.register_parameter('gamma', nn.Parameter(conv_orig.gamma.clone()))
            conv_q.register_parameter('beta', nn.Parameter(conv_orig.beta.clone()))

        del model_orig

        search_add_bn(model)

        print("Run BN tuning")
        for tt in range(args.tuning_iter):
            print(tt)
            trainer.cal_bn_stats(train_data.get_loader())

        search_absorbe_tuning_bn(model)

        filename = args.evaluate + '.bn_tuning'
        print("Save model to: {}".format(filename))
        torch.save(model.state_dict(), filename)

        val_results = trainer.validate(val_data.get_loader())
        logging.info(val_results)

        if args.res_log is not None:
            df = _read_res_log(args.res_log)
            ckp = ntpath.basename(args.evaluate)
            df.loc[ckp, 'acc_bn_tuning'] = val_results['prec1']
            df.loc[ckp, 'loss_bn_tuning'] = val_results['loss']
            df.to_csv(args.res_log)

    elif args.bias_tuning:
        for epoch in range(args.epochs):
            trainer.epoch = epoch
            train_data.set_epoch(epoch)
            val_data.set_epoch(epoch)
            logging.info('\nStarting Epoch: {0}\n'.format(epoch + 1))
            # train for one epoch
            repeat_train = 20 if args.update_only_th else 1
            for tt in range(repeat_train):
                print(tt)
                train_results = trainer.train(train_data.get_loader(),
                                              duplicates=train_data.get('duplicates'),
                                              chunk_batch=args.chunk_batch)
                logging.info(train_results)

        val_results = trainer.validate(val_data.get_loader())
        logging.info(val_results)
        if args.res_log is not None:
            df = _read_res_log(args.res_log)
            ckp = ntpath.basename(args.evaluate)
            if 'bn_tuning' in ckp:
                ckp = ckp.replace('.bn_tuning', '')
            df.loc[ckp, 'acc_bias_tuning'] = val_results['prec1']
            df.to_csv(args.res_log)

    else:
        # A missing 'measure' key is treated as False instead of raising KeyError.
        measure = model_config.get('measure', False)
        if measure:
            results = trainer.validate(train_data.get_loader(), rec=args.rec)
        elif args.evaluate_init_configuration:
            results = trainer.validate(val_data.get_loader())
            if args.res_log is not None:
                df = _read_res_log(args.res_log)
                ckp = ntpath.basename(args.evaluate)
                if args.cmp is not None:
                    ckp += '_{}'.format(args.cmp)
                df.loc[ckp, 'acc_base'] = results['prec1']
                df.to_csv(args.res_log)

        if args.extract_bias_mean:
            file_name = 'bias_mean_measure' if measure else 'bias_mean_quant'
            torch.save(trainer.bias_mean, file_name)
        if measure:
            filename = args.evaluate + '.measure'
            if 'perC' in args.model_config:
                filename += '_perC'
            torch.save(model.state_dict(), filename)
            logging.info(results)
        elif args.evaluate_init_configuration:
            logging.info(results)
    return acc, loss