def run_evaluation(args, model, data_loaders, model_description, n_choices, layers_types, downsample_layers):
    """Evaluate one architecture sampled from the supernet.

    Builds one-hot ``propagate`` weights from ``model_description``, profiles
    an equivalent single-path network to get MACs/params, recalibrates
    BatchNorm statistics on a dedicated loader, then records val/test
    loss+accuracy via ``utils.save_result``.
    """
    start = time.time()
    num_samples = utils.get_number_of_samples(args.dataset)
    all_values = {}
    device = 'cuda'
    # Setting up random seeds for reproducibility.
    utils.setup_torch(args.seed)
    # Creating model skeleton based on description: one one-hot row per layer
    # selecting which candidate operation is active at that position.
    propagate_weights = []
    for layer in model_description:
        cur_weights = [0 for i in range(n_choices)]
        cur_weights[layers_types.index(layer)] = 1
        propagate_weights.append(cur_weights)
    model.propagate = propagate_weights
    # Create the computationally identical model but without multiple choice
    # blocks (just a single path net).
    # This is needed to correctly measure MACs.
    pruned_model = models.SinglePathSupernet(
        num_classes=utils.get_number_of_classes(args.dataset),
        propagate=propagate_weights,
        put_downsampling=downsample_layers)  #.to(device)
    pruned_model.propagate = propagate_weights
    # Dummy input for profiling; assumes 3x32x32 (CIFAR-sized) inputs —
    # TODO confirm for other datasets.
    inputs = torch.randn((1, 3, 32, 32))
    total_ops, total_params = profile(pruned_model, (inputs, ), verbose=True)
    all_values['MMACs'] = np.round(total_ops / (1000.0**2), 2)
    all_values['Params'] = int(total_params)
    # Free the profiling model before the (GPU) evaluation below.
    del pruned_model
    del inputs
    ################################################
    criterion = torch.nn.CrossEntropyLoss()
    # Initialize batch normalization parameters for the chosen path.
    utils.bn_update(device, data_loaders['train_for_bn_recalc'], model)
    val_res = utils.evaluate(device, data_loaders['val'], model, criterion, num_samples['val'])
    test_res = utils.evaluate(device, data_loaders['test'], model, criterion, num_samples['test'])
    all_values['val_loss'] = np.round(val_res['loss'], 3)
    all_values['val_acc'] = np.round(val_res['accuracy'], 3)
    all_values['test_loss'] = np.round(test_res['loss'], 3)
    all_values['test_acc'] = np.round(test_res['accuracy'], 3)
    print(all_values, 'time taken: %.2f sec.' % (time.time() - start))
    utils.save_result(all_values, args.dir, model_description)
def swa_train(model, swa_model, train_iter, valid_iter, optimizer, criterion, pretrain_epochs, swa_epochs, swa_lr, cycle_length, device, writer, cpt_filename):
    """Run the SWA phase for ``swa_epochs`` epochs after pretraining.

    Each epoch trains ``model`` with a cyclic LR, folds its weights into the
    running average ``swa_model``, refreshes the averaged model's BatchNorm
    statistics, evaluates both nets and logs a table row + checkpoint.

    NOTE(review): relies on names not passed as parameters (``cpt_directory``,
    ``date``, ``lr_init``, ``columns``) — presumably module-level globals;
    confirm they exist in the enclosing module.
    """
    swa_n = 1
    # Start the running average from the current (pretrained) weights.
    swa_model.load_state_dict(copy.deepcopy(model.state_dict()))
    utils.save_checkpoint(
        cpt_directory, 1,
        '{}-swa-{:2.4f}-{:03d}-{}'.format(date, swa_lr, cycle_length, cpt_filename),
        state_dict=model.state_dict(),
        swa_state_dict=swa_model.state_dict(),
        swa_n=swa_n,
        optimizer=optimizer.state_dict()
    )
    for e in range(swa_epochs):
        # Global epoch index, offset by the pretraining epochs.
        epoch = e + pretrain_epochs
        time_ep = time.time()
        # Cyclic learning rate within each SWA cycle.
        lr = utils.schedule(epoch, cycle_length, lr_init, swa_lr)
        utils.adjust_learning_rate(optimizer, lr)
        train_res = utils.train_epoch(model, train_iter, optimizer, criterion, device)
        valid_res = utils.evaluate(model, valid_iter, criterion, device)
        # Fold the current weights into the running SWA average.
        # NOTE(review): ``swa_n`` is passed directly here, whereas sibling code
        # in this file passes 1.0 / (swa_n + 1) — verify the argument that
        # utils.moving_average expects.
        utils.moving_average(swa_model, model, swa_n)
        swa_n += 1
        # The averaged weights need recomputed BatchNorm running statistics.
        utils.bn_update(train_iter, swa_model)
        swa_res = utils.evaluate(swa_model, valid_iter, criterion, device)
        time_ep = time.time() - time_ep
        values = [epoch + 1, lr, swa_lr, cycle_length, train_res['loss'], valid_res['loss'], swa_res['loss'], None, None, time_ep]
        writer.writerow(values)
        table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f')
        # Re-print the table header every 20 epochs.
        if epoch % 20 == 0:
            table = table.split('\n')
            table = '\n'.join([table[1]] + table)
        else:
            table = table.split('\n')[2]
        print(table)
        utils.save_checkpoint(
            cpt_directory, epoch + 1,
            '{}-swa-{:2.4f}-{:03d}-{}'.format(date, swa_lr, cycle_length, cpt_filename),
            state_dict=model.state_dict(),
            swa_state_dict=swa_model.state_dict(),
            swa_n=swa_n,
            optimizer=optimizer.state_dict()
        )
def evaluate(self, torch_state_dict, update_buffers=False):
    """Load *torch_state_dict* into the wrapped model and score it.

    Optionally re-estimates BatchNorm running statistics on the training
    loader first (``update_buffers=True``).  Returns the pair
    ``(train_result, test_result)`` as produced by ``utils.eval``.
    """
    loss_fn = F.cross_entropy
    # Restore the candidate weights into the wrapped model.
    self.model.load_state_dict(torch_state_dict)
    if update_buffers:
        # Recompute BatchNorm running mean/var before scoring.
        utils.bn_update(self.loaders['train'], self.model)
    # Evaluate on both splits, in train-then-test order.
    results = tuple(
        utils.eval(self.loaders[split], self.model, loss_fn)
        for split in ('train', 'test')
    )
    return results
# NOTE(review): fragment of a per-epoch training-loop body — the enclosing
# `for epoch ...` header and the list literal that follows the trailing
# `values = [` are outside this chunk.
lr = schedule(epoch)
utils.adjust_learning_rate(optimizer, lr)
train_res = utils.train_epoch(loaders['train'], model, criterion, optimizer)
# Evaluate only on the first epoch, every eval_freq-th epoch, and the last.
if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
    test_res = utils.eval(loaders['test'], model, criterion)
else:
    test_res = {'loss': None, 'accuracy': None}
# SWA: fold current weights into the running average every swa_c_epochs
# once past swa_start.
if args.swa and (epoch + 1) >= args.swa_start and (
        epoch + 1 - args.swa_start) % args.swa_c_epochs == 0:
    utils.moving_average(swa_model, model, 1.0 / (swa_n + 1))
    swa_n += 1
    if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
        # Averaged weights need fresh BatchNorm statistics before eval.
        utils.bn_update(loaders['train'], swa_model)
        swa_res = utils.eval(loaders['test'], swa_model, criterion)
    else:
        swa_res = {'loss': None, 'accuracy': None}
if (epoch + 1) % args.save_freq == 0:
    utils.save_checkpoint(
        args.dir, epoch + 1,
        state_dict=model.state_dict(),
        swa_state_dict=swa_model.state_dict() if args.swa else None,
        swa_n=swa_n if args.swa else None,
        optimizer=optimizer.state_dict())
time_ep = time.time() - time_ep
values = [
def train_main(cfg):
    """Main training entry point.

    Builds datasets/dataloaders, the model (optionally resuming into SWA
    mode), optimizer, LR scheduler and joint Dice + soft-CE loss, then runs
    the epoch loop with periodic evaluation and checkpointing.

    :param cfg: configuration object (attribute access + ``setdefault``).
    :return: None

    Fix: ``worker_init_fn`` previously received ``_init_fn()`` (the CALL
    result, i.e. ``None``), so DataLoader workers were never seeded; it now
    receives the function itself, seeded per worker id.
    """
    # config
    train_cfg = cfg.train_cfg
    dataset_cfg = cfg.dataset_cfg
    model_cfg = cfg.model_cfg
    is_parallel = cfg.setdefault(key='is_parallel', default=False)
    device = cfg.device
    is_online_train = cfg.setdefault(key='is_online_train', default=False)

    # configure the logger
    logging.basicConfig(filename=cfg.logfile,
                        filemode='a',
                        level=logging.INFO,
                        format='%(asctime)s\n%(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S')
    logger = logging.getLogger()

    # build the datasets
    train_dataset = LandDataset(DIR_list=dataset_cfg.train_dir_list,
                                mode='train',
                                input_channel=dataset_cfg.input_channel,
                                transform=dataset_cfg.train_transform)
    split_val_from_train_ratio = dataset_cfg.setdefault(
        key='split_val_from_train_ratio', default=None)
    if split_val_from_train_ratio is None:
        val_dataset = LandDataset(DIR_list=dataset_cfg.val_dir_list,
                                  mode='val',
                                  input_channel=dataset_cfg.input_channel,
                                  transform=dataset_cfg.val_transform)
    else:
        # Carve the validation set out of the training set.
        val_size = int(len(train_dataset) * split_val_from_train_ratio)
        train_size = len(train_dataset) - val_size
        train_dataset, val_dataset = random_split(
            train_dataset, [train_size, val_size],
            generator=torch.manual_seed(cfg.random_seed))
        # val_dataset.dataset.transform = dataset_cfg.val_transform  # NOTE(review): val split keeps the train transform — confirm intended
        print(f"按照{split_val_from_train_ratio}切分训练集...")

    # build the dataloaders
    def _init_fn(worker_id=0):
        # Seed NumPy in each worker; offset by worker_id so workers do not
        # draw identical random streams.
        np.random.seed(cfg.random_seed + worker_id)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=train_cfg.batch_size,
                                  shuffle=True,
                                  num_workers=train_cfg.num_workers,
                                  drop_last=True,
                                  worker_init_fn=_init_fn)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=train_cfg.batch_size,
                                num_workers=train_cfg.num_workers,
                                shuffle=False,
                                drop_last=True,
                                worker_init_fn=_init_fn)

    # build the model
    if train_cfg.is_swa:
        # Pass device via map_location: otherwise the checkpoint is first
        # loaded onto cuda:0 and only then moved to the target device.
        model = torch.load(train_cfg.check_point_file,
                           map_location=device).to(device)
        swa_model = torch.load(train_cfg.check_point_file,
                               map_location=device).to(device)
        if is_parallel:
            model = torch.nn.DataParallel(model)
            swa_model = torch.nn.DataParallel(swa_model)
        swa_n = 0
        parameters = swa_model.parameters()
    else:
        model = build_model(model_cfg).to(device)
        if is_parallel:
            model = torch.nn.DataParallel(model)
        parameters = model.parameters()

    # define the optimizer
    optimizer_cfg = train_cfg.optimizer_cfg
    lr_scheduler_cfg = train_cfg.lr_scheduler_cfg
    if optimizer_cfg.type == 'adam':
        optimizer = optim.Adam(params=parameters,
                               lr=optimizer_cfg.lr,
                               weight_decay=optimizer_cfg.weight_decay)
    elif optimizer_cfg.type == 'adamw':
        optimizer = optim.AdamW(params=parameters,
                                lr=optimizer_cfg.lr,
                                weight_decay=optimizer_cfg.weight_decay)
    elif optimizer_cfg.type == 'sgd':
        optimizer = optim.SGD(params=parameters,
                              lr=optimizer_cfg.lr,
                              momentum=optimizer_cfg.momentum,
                              weight_decay=optimizer_cfg.weight_decay)
    elif optimizer_cfg.type == 'RMS':
        optimizer = optim.RMSprop(params=parameters,
                                  lr=optimizer_cfg.lr,
                                  weight_decay=optimizer_cfg.weight_decay)
    else:
        raise Exception('没有该优化器!')

    # define the LR scheduler
    if not lr_scheduler_cfg:
        lr_scheduler = None
    elif lr_scheduler_cfg.policy == 'cos':
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            lr_scheduler_cfg.T_0,
            lr_scheduler_cfg.T_mult,
            lr_scheduler_cfg.eta_min,
            last_epoch=lr_scheduler_cfg.last_epoch)
    elif lr_scheduler_cfg.policy == 'LambdaLR':
        import math
        lf = lambda x: (((1 + math.cos(x * math.pi / train_cfg.num_epochs)) / 2
                         )**1.0) * 0.95 + 0.05  # cosine
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                         lr_lambda=lf)
        lr_scheduler.last_epoch = 0
    else:
        lr_scheduler = None

    # define the loss: joint Dice + label-smoothed cross entropy
    DiceLoss_fn = DiceLoss(mode='multiclass')
    SoftCrossEntropy_fn = SoftCrossEntropyLoss(smooth_factor=0.1)
    loss_func = L.JointLoss(first=DiceLoss_fn,
                            second=SoftCrossEntropy_fn,
                            first_weight=0.5,
                            second_weight=0.5).cuda()
    # loss_cls_func = torch.nn.BCEWithLogitsLoss()

    # create the checkpoint directory if it does not exist
    # NOTE(review): os.mkdir fails when intermediate directories are missing;
    # os.makedirs would be more robust — confirm before changing.
    check_point_dir = '/'.join(model_cfg.check_point_file.split('/')[:-1])
    if not os.path.exists(check_point_dir):
        os.mkdir(check_point_dir)

    # start training
    # Epochs at which to always save a snapshot.
    # NOTE(review): the int default (5) would make `epoch in auto_save_epoch_list`
    # raise TypeError — configs presumably always provide a list; confirm.
    auto_save_epoch_list = train_cfg.setdefault(key='auto_save_epoch_list',
                                                default=5)
    train_loss_list = []
    val_loss_list = []
    val_loss_min = 999999
    best_epoch = 0
    best_miou = 0
    train_loss = 10  # initial placeholder value
    logger.info('开始在{}上训练{}模型...'.format(device, model_cfg.type))
    logger.info('补充信息:{}\n'.format(cfg.setdefault(key='info', default='None')))
    for epoch in range(train_cfg.num_epochs):
        print()
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        start_time = time.time()
        print(f"正在进行第{epoch}轮训练...")
        logger.info('*' * 10 + f"第{epoch}轮" + '*' * 10)

        # train one epoch
        if train_cfg.is_swa:
            # SWA: train the live model, fold it into the running average,
            # then refresh the averaged model's BatchNorm statistics.
            train_loss = train_epoch(swa_model, optimizer, lr_scheduler,
                                     loss_func, train_dataloader, epoch,
                                     device)
            moving_average(model, swa_model, 1.0 / (swa_n + 1))
            swa_n += 1
            bn_update(train_dataloader, model, device)
        else:
            train_loss = train_epoch(model, optimizer, lr_scheduler,
                                     loss_func, train_dataloader, epoch,
                                     device)

        # evaluate (only needed for offline training)
        if not is_online_train:
            val_loss, val_miou = evaluate_model(model, val_dataloader,
                                                loss_func, device,
                                                cfg.num_classes)
        else:
            val_loss = 0
            val_miou = 0

        train_loss_list.append(train_loss)
        val_loss_list.append(val_loss)

        # save the best model (offline training only)
        if not is_online_train:
            if val_loss < val_loss_min:
                val_loss_min = val_loss
                best_epoch = epoch
                best_miou = val_miou
                if is_parallel:
                    torch.save(model.module, model_cfg.check_point_file)
                else:
                    torch.save(model, model_cfg.check_point_file)
        # snapshot at configured epochs
        if epoch in auto_save_epoch_list:
            model_file = model_cfg.check_point_file.split(
                '.pth')[0] + '-epoch{}.pth'.format(epoch)
            if is_parallel:
                torch.save(model.module, model_file)
            else:
                torch.save(model, model_file)

        # print intermediate results
        end_time = time.time()
        run_time = int(end_time - start_time)
        m, s = divmod(run_time, 60)
        time_str = "{:02d}分{:02d}秒".format(m, s)
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        out_str = "第{}轮训练完成,耗时{},\t训练集上的loss={:.6f};\t验证集上的loss={:.4f},mIoU={:.6f}\t最好的结果是第{}轮,mIoU={:.6f}" \
            .format(epoch, time_str, train_loss, val_loss, val_miou, best_epoch, best_miou)
        print(out_str)
        logger.info(out_str + '\n')
# NOTE(review): fragment — the first two lines are the tail of an optimizer
# constructor whose opening call is outside this chunk, and the trailing
# tabulate.tabulate( call is cut off mid-statement.
momentum=args.momentum,
weight_decay=args.wd)
start_epoch = 0
columns = ['ep', 'lr', 'tr_loss', 'tr_acc', 'te_loss', 'te_acc', 'time']
for epoch in range(start_epoch, args.epochs):
    time_ep = time.time()
    lr = schedule(epoch)
    utils.adjust_learning_rate(optimizer, lr)
    train_res = utils.train_epoch(loaders['train'], model, criterion, optimizer)
    # Evaluate on the first, every eval_freq-th, and the final epoch;
    # BatchNorm refresh and eval both run in half precision on CUDA.
    if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
        utils.bn_update(loaders['train'], model, use_half=True, use_cuda=True)
        test_res = utils.eval(loaders['test'], model, criterion, use_half=True, use_cuda=True)
    else:
        test_res = {'loss': None, 'accuracy': None}
    time_ep = time.time() - time_ep
    values = [
        epoch + 1, lr, train_res['loss'], train_res['accuracy'],
        test_res['loss'], test_res['accuracy'], time_ep
    ]
    table = tabulate.tabulate([values], columns,
criterion = F.cross_entropy # optimizer = torch.optim.SGD( # model_temp.parameters(), # lr=args.lr_init, # momentum=args.momentum, # weight_decay=args.wd # ) start_epoch = 0 if args.model1_resume is not None: print('Resume training from %s' % args.model1_resume) checkpoint = torch.load(args.model1_resume) start_epoch = checkpoint['epoch'] model_1.load_state_dict(checkpoint['state_dict']) utils.bn_update(loaders['train'], model_1) print(utils.eval(loaders['train'], model_1, criterion)) vec_1 = parameters_to_vector(model_1.parameters()) if args.model2_resume is not None: print('Resume training from %s' % args.model2_resume) checkpoint = torch.load(args.model2_resume) start_epoch = checkpoint['epoch'] model_2.load_state_dict(checkpoint['swa_state_dict']) model_temp.load_state_dict(checkpoint['state_dict']) utils.bn_update(loaders['train'], model_2) print(utils.eval(loaders['train'], model_2, criterion)) vec_2 = parameters_to_vector(model_2.parameters()) vec_inter = vec_1 - vec_2 # vec_inter_norm = torch.norm(vec_inter)
def main():
    """Distiller-style compression/training driver with optional SWA.

    Parses CLI args, configures logging and devices, creates the model (and
    its SWA copy when ``--swa``), optionally resumes from a checkpoint, then
    dispatches to summary/sensitivity/evaluate/thinnify/lr-find modes or runs
    the train/validate/checkpoint epoch loop followed by a final test pass.
    """
    script_dir = os.path.dirname(__file__)
    module_path = os.path.abspath(os.path.join(script_dir, '..', '..'))
    global msglogger

    # Parse arguments
    args = parser.get_parser().parse_args()
    if args.epochs is None:
        args.epochs = 90

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    msglogger = apputils.config_pylogger(
        os.path.join(script_dir, 'logging.conf'), args.name, args.output_dir)

    # Log various details about the execution environment. It is sometimes useful
    # to refer to past experiment executions and this information may be useful.
    apputils.log_execution_env_state(args.compress,
                                     msglogger.logdir,
                                     gitroot=module_path)
    msglogger.debug("Distiller: %s", distiller.__version__)

    start_epoch = 0
    ending_epoch = args.epochs
    perf_scores_history = []

    if args.evaluate:
        args.deterministic = True
    if args.deterministic:
        # Experiment reproducibility is sometimes important. Pete Warden expounded about this
        # in his blog: https://petewarden.com/2018/03/19/the-machine-learning-reproducibility-crisis/
        distiller.set_deterministic(
        )  # Use a well-known seed, for repeatability of experiments
    else:
        # Turn on CUDNN benchmark mode for best performance. This is usually "safe" for image
        # classification models, as the input sizes don't change during the run
        # See here: https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3
        cudnn.benchmark = True

    if args.cpu or not torch.cuda.is_available():
        # Set GPU index to -1 if using CPU
        args.device = 'cpu'
        args.gpus = -1
    else:
        args.device = 'cuda'
        if args.gpus is not None:
            try:
                args.gpus = [int(s) for s in args.gpus.split(',')]
            except ValueError:
                raise ValueError(
                    'ERROR: Argument --gpus must be a comma-separated list of integers only'
                )
            available_gpus = torch.cuda.device_count()
            for dev_id in args.gpus:
                if dev_id >= available_gpus:
                    raise ValueError(
                        'ERROR: GPU device ID {0} requested, but only {1} devices available'
                        .format(dev_id, available_gpus))
            # Set default device in case the first one on the list != 0
            torch.cuda.set_device(args.gpus[0])

    # Infer the dataset from the model name
    args.dataset = 'cifar10' if 'cifar' in args.arch else 'imagenet'
    args.num_classes = 10 if args.dataset == 'cifar10' else 1000

    # Create the model (plus an identical copy for SWA averaging if enabled)
    model = create_model(args.pretrained,
                         args.dataset,
                         args.arch,
                         parallel=not args.load_serialized,
                         device_ids=args.gpus)
    if args.swa:
        swa_model = create_model(args.pretrained,
                                 args.dataset,
                                 args.arch,
                                 parallel=not args.load_serialized,
                                 device_ids=args.gpus)
        swa_n = 0

    compression_scheduler = None
    # Create a couple of logging backends. TensorBoardLogger writes log files in a format
    # that can be read by Google's Tensor Board. PythonLogger writes to the Python logger.
    tflogger = TensorBoardLogger(msglogger.logdir)
    pylogger = PythonLogger(msglogger)

    # TODO(barrh): args.deprecated_resume is deprecated since v0.3.1
    if args.deprecated_resume:
        msglogger.warning(
            'The "--resume" flag is deprecated. Please use "--resume-from=YOUR_PATH" instead.'
        )
        if not args.reset_optimizer:
            msglogger.warning(
                'If you wish to also reset the optimizer, call with: --reset-optimizer'
            )
            args.reset_optimizer = True
        args.resumed_checkpoint_path = args.deprecated_resume

    # We can optionally resume from a checkpoint
    optimizer = None
    # TODO: resume from swa mode
    if args.resumed_checkpoint_path:
        if args.swa:
            model, swa_model, swa_n, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint(
                model,
                args.resumed_checkpoint_path,
                swa_model=swa_model,
                swa_n=swa_n,
                model_device=args.device)
        else:
            model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint(
                model, args.resumed_checkpoint_path, model_device=args.device)
    elif args.load_model_path:
        model = apputils.load_lean_checkpoint(model,
                                              args.load_model_path,
                                              model_device=args.device)
    if args.reset_optimizer:
        start_epoch = 0
        if optimizer is not None:
            optimizer = None
            msglogger.info(
                '\nreset_optimizer flag set: Overriding resumed optimizer and resetting epoch count to 0'
            )

    # Define loss function (criterion)
    criterion = nn.CrossEntropyLoss().to(args.device)

    if optimizer is None:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        msglogger.info('Optimizer Type: %s', type(optimizer))
        msglogger.info('Optimizer Args: %s', optimizer.defaults)

    # This sample application can be invoked to produce various summary reports.
    if args.summary:
        return summarize_model(model, args.dataset, which_summary=args.summary)

    activations_collectors = create_activation_stats_collectors(
        model, *args.activation_stats)

    # Load the datasets: the dataset to load is inferred from the model name passed
    # in args.arch. The default dataset is ImageNet, but if args.arch contains the
    # substring "_cifar", then cifar10 is used.
    train_loader, val_loader, test_loader, _ = apputils.load_data(
        args.dataset, os.path.expanduser(args.data), args.batch_size,
        args.workers, args.validation_split, args.deterministic,
        args.effective_train_size, args.effective_valid_size,
        args.effective_test_size)
    msglogger.info('Dataset sizes:\n\ttraining=%d\n\tvalidation=%d\n\ttest=%d',
                   len(train_loader.sampler), len(val_loader.sampler),
                   len(test_loader.sampler))

    if args.sensitivity is not None:
        sensitivities = np.arange(args.sensitivity_range[0],
                                  args.sensitivity_range[1],
                                  args.sensitivity_range[2])
        return sensitivity_analysis(model, criterion, test_loader, pylogger,
                                    args, sensitivities)

    if args.evaluate:
        return evaluate_model(model, criterion, test_loader, pylogger,
                              activations_collectors, args,
                              compression_scheduler)

    if args.compress:
        # The main use-case for this sample application is CNN compression. Compression
        # requires a compression schedule configuration file in YAML.
        compression_scheduler = distiller.file_config(
            model, optimizer, args.compress, compression_scheduler,
            (start_epoch - 1) if args.resumed_checkpoint_path else None)
        # Model is re-transferred to GPU in case parameters were added (e.g. PACTQuantizer)
        model.to(args.device)
    elif compression_scheduler is None:
        compression_scheduler = distiller.CompressionScheduler(model)

    if args.thinnify:
        #zeros_mask_dict = distiller.create_model_masks_dict(model)
        assert args.resumed_checkpoint_path is not None, \
            "You must use --resume-from to provide a checkpoint file to thinnify"
        distiller.remove_filters(model,
                                 compression_scheduler.zeros_mask_dict,
                                 args.arch,
                                 args.dataset,
                                 optimizer=None)
        apputils.save_checkpoint(0,
                                 args.arch,
                                 model,
                                 optimizer=None,
                                 scheduler=compression_scheduler,
                                 name="{}_thinned".format(
                                     args.resumed_checkpoint_path.replace(
                                         ".pth.tar", "")),
                                 dir=msglogger.logdir)
        print(
            "Note: your model may have collapsed to random inference, so you may want to fine-tune"
        )
        return

    if args.lr_find:
        lr_finder = distiller.LRFinder(model,
                                       optimizer,
                                       criterion,
                                       device=args.device)
        lr_finder.range_test(train_loader, end_lr=10, num_iter=100)
        lr_finder.plot()
        return

    if start_epoch >= ending_epoch:
        msglogger.error(
            'epoch count is too low, starting epoch is {} but total epochs set to {}'
            .format(start_epoch, ending_epoch))
        raise ValueError('Epochs parameter is too low. Nothing to do.')

    for epoch in range(start_epoch, ending_epoch):
        # This is the main training loop.
        msglogger.info('\n')
        if compression_scheduler:
            # First epoch has no validation loss yet; use a large sentinel.
            compression_scheduler.on_epoch_begin(
                epoch, metrics=(vloss if (epoch != start_epoch) else 10**6))

        # Train for one epoch
        with collectors_context(activations_collectors["train"]) as collectors:
            train(train_loader,
                  model,
                  criterion,
                  optimizer,
                  epoch,
                  compression_scheduler,
                  loggers=[tflogger, pylogger],
                  args=args)
            # distiller.log_weights_sparsity(model, epoch, loggers=[tflogger, pylogger])
            # distiller.log_activation_statsitics(epoch, "train", loggers=[tflogger],
            #                                     collector=collectors["sparsity"])
            if args.masks_sparsity:
                msglogger.info(
                    distiller.masks_sparsity_tbl_summary(
                        model, compression_scheduler))

        # evaluate on validation set
        with collectors_context(activations_collectors["valid"]) as collectors:
            top1, top5, vloss = validate(val_loader, model, criterion,
                                         [pylogger], args, epoch)
            msglogger.info('==> Top1: %.3f Top5: %.3f Loss: %.3f\n', top1,
                           top5, vloss)
            distiller.log_activation_statsitics(
                epoch,
                "valid",
                loggers=[tflogger],
                collector=collectors["sparsity"])
            save_collectors_data(collectors, msglogger.logdir)

        stats = ('Performance/Validation/',
                 OrderedDict([('Loss', vloss), ('Top1', top1),
                              ('Top5', top5)]))

        # SWA update: fold current weights into the running average every
        # swa_freq epochs once past swa_start, and on the final epoch.
        # NOTE(review): due to and/or precedence, the `or epoch == ending_epoch - 1`
        # branch fires even when args.swa is False, which would raise NameError
        # on swa_model — confirm and parenthesize if unintended.
        if args.swa and (epoch + 1) >= args.swa_start and (
                epoch + 1 - args.swa_start
        ) % args.swa_freq == 0 or epoch == ending_epoch - 1:
            utils.moving_average(swa_model, model, 1. / (swa_n + 1))
            swa_n += 1
            # Averaged weights need recomputed BatchNorm statistics.
            utils.bn_update(train_loader, swa_model, args)
            swa_top1, swa_top5, swa_loss = validate(val_loader, swa_model,
                                                    criterion, [pylogger],
                                                    args, epoch)
            msglogger.info(
                '==> SWA_Top1: %.3f SWA_Top5: %.3f SWA_Loss: %.3f\n',
                swa_top1, swa_top5, swa_loss)
            swa_res = OrderedDict([('SWA_Loss', swa_loss),
                                   ('SWA_Top1', swa_top1),
                                   ('SWA_Top5', swa_top5)])
            stats[1].update(swa_res)

        distiller.log_training_progress(stats,
                                        None,
                                        epoch,
                                        steps_completed=0,
                                        total_steps=1,
                                        log_freq=1,
                                        loggers=[tflogger])

        if compression_scheduler:
            compression_scheduler.on_epoch_end(epoch, optimizer)

        # Update the list of top scores achieved so far, and save the checkpoint
        update_training_scores_history(perf_scores_history, model, top1, top5,
                                       epoch, args.num_best_scores)
        is_best = epoch == perf_scores_history[0].epoch
        checkpoint_extras = {
            'current_top1': top1,
            'best_top1': perf_scores_history[0].top1,
            'best_epoch': perf_scores_history[0].epoch
        }
        if args.swa:
            apputils.save_checkpoint(epoch,
                                     args.arch,
                                     model,
                                     swa_model,
                                     swa_n,
                                     optimizer=optimizer,
                                     scheduler=compression_scheduler,
                                     extras=checkpoint_extras,
                                     is_best=is_best,
                                     name=args.name,
                                     dir=msglogger.logdir)
        else:
            apputils.save_checkpoint(epoch,
                                     args.arch,
                                     model,
                                     optimizer=optimizer,
                                     scheduler=compression_scheduler,
                                     extras=checkpoint_extras,
                                     is_best=is_best,
                                     name=args.name,
                                     dir=msglogger.logdir)

    # Finally run results on the test set
    test(test_loader,
         model,
         criterion, [pylogger],
         activations_collectors,
         args=args)
    if args.swa:
        test(test_loader,
             swa_model,
             criterion, [pylogger],
             activations_collectors,
             args=args)
def run_evaluation(model, ensemble_model, data_loaders, args, save_model='', load_model=''):
    """Train ``model`` with a cyclic LR and maintain an SWA ensemble.

    Profiles the model (MACs/params), optionally restores weights from
    ``load_model``, then trains for ``args['epochs']`` epochs.  At the end of
    each LR cycle, both the single model and the running-average ensemble are
    evaluated on val/test and all metrics are accumulated in ``all_values``,
    which is returned.  If ``save_model`` is non-empty, both nets' state
    dicts are written under ``args['dir']``.
    """
    all_values = {}
    device = 'cuda'
    utils.setup_torch(args['seed'])
    # Profile MACs/params on a dummy input of the configured size.
    inputs = torch.randn(
        (1, args['input_channels'], args['img_size'], args['img_size']))
    total_ops, total_params = profile(model, (inputs, ), verbose=True)
    all_values['MMACs'] = np.round(total_ops / (1000.0**2), 2)
    all_values['Params'] = int(total_params)
    print(all_values)
    start = time.time()
    model = model.to(device)
    ensemble_model = ensemble_model.to(device)
    print('models to device', time.time() - start)
    if len(load_model) > 0:
        model.load_state_dict(torch.load(os.path.join(args['dir'], load_model)))
    # NOTE(review): this assignment is immediately shadowed by the .to(device)
    # version below — presumably dead code.
    criterion = torch.nn.CrossEntropyLoss()
    ################################################
    # NOTE(review): summary uses a hard-coded (3, 32, 32) shape while the
    # profiler above used args-driven dimensions — confirm consistency.
    summary(model, (3, 32, 32), batch_size=args['batch_size'], device='cuda')
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args['lr_init'],
                                momentum=0.9,
                                weight_decay=1e-4)
    lrs = []
    n_models = 0
    all_values['epoch'] = []
    all_values['overall_time'] = []
    all_values['lr'] = []
    all_values['tr_loss'] = []
    all_values['tr_acc'] = []
    all_values['val_loss_single'] = []
    all_values['val_acc_single'] = []
    all_values['val_loss_ensemble'] = []
    all_values['val_acc_ensemble'] = []
    all_values['test_loss_single'] = []
    all_values['test_acc_single'] = []
    all_values['test_loss_ensemble'] = []
    all_values['test_acc_ensemble'] = []
    n_models = 0  # NOTE(review): duplicate of the assignment above
    time_start = time.time()
    for epoch in range(args['epochs']):
        time_ep = time.time()
        lr = utils.get_cyclic_lr(epoch, lrs, args['lr_init'],
                                 args['lr_start_cycle'], args['cycle_period'])
        #print ('lr=%.3f' % lr)
        utils.set_learning_rate(optimizer, lr)
        lrs.append(lr)
        train_res = utils.train_epoch(device, data_loaders['train'], model,
                                      criterion, optimizer,
                                      args['num_samples_train'])
        values = [epoch + 1, lr, train_res['loss'], train_res['accuracy']]
        # At the end of each LR cycle: evaluate the single model, fold its
        # weights into the ensemble, refresh the ensemble's BatchNorm
        # statistics, and evaluate the ensemble.
        if (epoch + 1) >= args['lr_start_cycle'] and (
                epoch + 1) % args['cycle_period'] == 0:
            all_values['epoch'].append(epoch + 1)
            all_values['lr'].append(lr)
            all_values['tr_loss'].append(train_res['loss'])
            all_values['tr_acc'].append(train_res['accuracy'])
            val_res = utils.evaluate(device, data_loaders['val'], model,
                                     criterion, args['num_samples_val'])
            test_res = utils.evaluate(device, data_loaders['test'], model,
                                      criterion, args['num_samples_test'])
            all_values['val_loss_single'].append(val_res['loss'])
            all_values['val_acc_single'].append(val_res['accuracy'])
            all_values['test_loss_single'].append(test_res['loss'])
            all_values['test_acc_single'].append(test_res['accuracy'])
            utils.moving_average_ensemble(ensemble_model, model,
                                          1.0 / (n_models + 1))
            utils.bn_update(device, data_loaders['train_for_bn_recalc'],
                            ensemble_model)
            n_models += 1
            val_res = utils.evaluate(device, data_loaders['val'],
                                     ensemble_model, criterion,
                                     args['num_samples_val'])
            test_res = utils.evaluate(device, data_loaders['test'],
                                      ensemble_model, criterion,
                                      args['num_samples_test'])
            all_values['val_loss_ensemble'].append(val_res['loss'])
            all_values['val_acc_ensemble'].append(val_res['accuracy'])
            all_values['test_loss_ensemble'].append(test_res['loss'])
            all_values['test_acc_ensemble'].append(test_res['accuracy'])
            overall_training_time = time.time() - time_start
            all_values['overall_time'].append(overall_training_time)
        #print (epoch, 'epoch_time', time.time() - time_ep)
    overall_training_time = time.time() - time_start
    #print ('overall time', overall_training_time)
    #print (all_values)
    if len(save_model) > 0:
        torch.save(ensemble_model.state_dict(),
                   os.path.join(args['dir'], save_model + '_ensemble'))
        torch.save(model.state_dict(), os.path.join(args['dir'], save_model))
    return all_values
def train_oneshot_model(args, data_loaders, n_cells, n_choices, put_downsampling=None):
    """Train the one-shot supernet and its SWA-averaged copy.

    Builds a supernet with all candidate operations enabled, trains it with a
    cyclic learning rate, and at the end of every LR cycle folds the live
    weights into ``ensemble_model`` (SWA) and evaluates both on val/test.
    Final weights and metric history are saved under ``args.dir``.

    :param put_downsampling: positions of downsampling cells; ``None`` means
        none (fix: the previous ``=[]`` default was a shared mutable object).

    Fix: replaced the mutable default argument ``put_downsampling=[]`` with a
    ``None`` sentinel resolved to a fresh list inside the function.
    """
    if put_downsampling is None:
        put_downsampling = []
    num_samples = utils.get_number_of_samples(args.dataset)
    device = 'cuda'
    utils.setup_torch(args.seed)
    print('Initializing model...')
    # Create a supernet skeleton (include all cell types for each position).
    propagate_weights = [[1, 1, 1] for i in range(n_cells)]
    model_class = getattr(models, 'Supernet')
    # Create the supernet model and its SWA ensemble version.
    model = model_class(num_classes=utils.get_number_of_classes(args.dataset),
                        propagate=propagate_weights,
                        training=True,
                        n_choices=n_choices,
                        put_downsampling=put_downsampling).to(device)
    ensemble_model = model_class(num_classes=utils.get_number_of_classes(
        args.dataset),
                                 propagate=propagate_weights,
                                 training=True,
                                 n_choices=n_choices,
                                 put_downsampling=put_downsampling).to(device)
    # These summaries are for verification purposes only.
    # However, removing them will cause inconsistency in results since random
    # generators are used inside them to propagate.
    summary(model, (3, 32, 32), batch_size=args.batch_size, device='cuda')
    summary(ensemble_model, (3, 32, 32),
            batch_size=args.batch_size,
            device='cuda')
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr_init,
                                momentum=0.9,
                                weight_decay=1e-4)
    start_epoch = 0
    columns = [
        'epoch time', 'overall training time', 'epoch', 'lr', 'train_loss',
        'train_acc', 'val_loss', 'val_acc', 'test_loss', 'test_acc'
    ]
    lrs = []
    n_models = 0
    # Metric history accumulated across epochs and saved at the end.
    all_values = {}
    all_values['epoch'] = []
    all_values['lr'] = []
    all_values['tr_loss'] = []
    all_values['tr_acc'] = []
    all_values['val_loss'] = []
    all_values['val_acc'] = []
    all_values['test_loss'] = []
    all_values['test_acc'] = []
    n_models = 0
    print('Start training...')
    time_start = time.time()
    for epoch in range(start_epoch, args.epochs):
        time_ep = time.time()
        #lr = utils.get_cosine_annealing_lr(epoch, args.lr_init, args.epochs)
        lr = utils.get_cyclic_lr(epoch, lrs, args.lr_init,
                                 args.lr_start_cycle, args.cycle_period)
        utils.set_learning_rate(optimizer, lr)
        lrs.append(lr)
        train_res = utils.train_epoch(device, data_loaders['train'], model,
                                      criterion, optimizer,
                                      num_samples['train'])
        values = [epoch + 1, lr, train_res['loss'], train_res['accuracy']]
        # End of an LR cycle: evaluate the supernet, then fold its weights
        # into the SWA ensemble and refresh the ensemble's BatchNorm stats.
        if (epoch + 1) >= args.lr_start_cycle and (epoch +
                                                   1) % args.cycle_period == 0:
            all_values['epoch'].append(epoch + 1)
            all_values['lr'].append(lr)
            all_values['tr_loss'].append(train_res['loss'])
            all_values['tr_acc'].append(train_res['accuracy'])
            val_res = utils.evaluate(device, data_loaders['val'], model,
                                     criterion, num_samples['val'])
            test_res = utils.evaluate(device, data_loaders['test'], model,
                                      criterion, num_samples['test'])
            all_values['val_loss'].append(val_res['loss'])
            all_values['val_acc'].append(val_res['accuracy'])
            all_values['test_loss'].append(test_res['loss'])
            all_values['test_acc'].append(test_res['accuracy'])
            values += [
                val_res['loss'], val_res['accuracy'], test_res['loss'],
                test_res['accuracy']
            ]
            utils.moving_average_ensemble(ensemble_model, model,
                                          1.0 / (n_models + 1))
            utils.bn_update(device, data_loaders['train'], ensemble_model)
            n_models += 1
            print(all_values)
        overall_training_time = time.time() - time_start
        values = [time.time() - time_ep, overall_training_time] + values
        table = tabulate.tabulate([values],
                                  columns,
                                  tablefmt='simple',
                                  floatfmt='8.4f')
        print(table)
    print('Training finished. Saving final nets...')
    utils.save_result(all_values, args.dir, 'model_supernet')
    torch.save(model.state_dict(), args.dir + '/supernet.pth')
    torch.save(ensemble_model.state_dict(), args.dir + '/supernet_swa.pth')
def main():
    """SWA/SGD CIFAR training driver.

    Loads a torchvision dataset, builds the model from ``model_cfg`` (plus an
    SWA copy when ``args.swa``), optionally resumes, then runs the epoch loop
    with LR scheduling, periodic evaluation, SWA averaging and checkpointing.

    NOTE(review): relies on module-level ``args``, ``model_cfg`` and ``utils``
    defined elsewhere in the file.
    """
    ds = getattr(torchvision.datasets, args.dataset)
    path = os.path.join(args.data_path, args.dataset.lower())
    train_set = ds(path,
                   train=True,
                   download=True,
                   transform=model_cfg.transform_train)
    test_set = ds(path,
                  train=False,
                  download=True,
                  transform=model_cfg.transform_test)
    loaders = {
        'train':
        torch.utils.data.DataLoader(train_set,
                                    batch_size=args.batch_size,
                                    shuffle=True,
                                    num_workers=args.num_workers,
                                    pin_memory=True),
        'test':
        torch.utils.data.DataLoader(test_set,
                                    batch_size=args.batch_size,
                                    shuffle=False,
                                    num_workers=args.num_workers,
                                    pin_memory=True)
    }
    num_classes = len(train_set.classes)  # max(train_set.train_labels) + 1
    print(num_classes)

    print('Preparing model')
    model = model_cfg.base(*model_cfg.args,
                           num_classes=num_classes,
                           **model_cfg.kwargs)
    model.cuda()

    if args.swa:
        print('SWA training')
        # Separate copy that accumulates the running weight average.
        swa_model = model_cfg.base(*model_cfg.args,
                                   num_classes=num_classes,
                                   **model_cfg.kwargs)
        swa_model.cuda()
        swa_n = 0
    else:
        print('SGD training')

    def schedule(epoch):
        # Piecewise LR schedule: constant for the first half, linear decay to
        # lr_ratio between 50% and 90% of the budget, then constant.
        t = (epoch) / (args.swa_start if args.swa else args.epochs)
        lr_ratio = args.swa_lr / args.lr_init if args.swa else 0.01
        if t <= 0.5:
            factor = 1.0
        elif t <= 0.9:
            factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
        else:
            factor = lr_ratio
        return args.lr_init * factor

    criterion = F.cross_entropy
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr_init,
                                momentum=args.momentum,
                                weight_decay=args.wd)

    start_epoch = 0
    if args.resume is not None:
        print('Resume training from %s' % args.resume)
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if args.swa:
            swa_state_dict = checkpoint['swa_state_dict']
            if swa_state_dict is not None:
                swa_model.load_state_dict(swa_state_dict)
            swa_n_ckpt = checkpoint['swa_n']
            if swa_n_ckpt is not None:
                swa_n = swa_n_ckpt

    columns = ['ep', 'lr', 'tr_loss', 'tr_acc', 'te_loss', 'te_acc', 'time']
    if args.swa:
        columns = columns[:-1] + ['swa_te_loss', 'swa_te_acc'] + columns[-1:]
        swa_res = {'loss': None, 'accuracy': None}

    utils.save_checkpoint(
        args.dir,
        start_epoch,
        state_dict=model.state_dict(),
        swa_state_dict=swa_model.state_dict() if args.swa else None,
        swa_n=swa_n if args.swa else None,
        optimizer=optimizer.state_dict())

    for epoch in range(start_epoch, args.epochs):
        time_ep = time.time()
        lr = schedule(epoch)
        utils.adjust_learning_rate(optimizer, lr)
        train_res = utils.train_epoch(loaders['train'], model, criterion,
                                      optimizer)
        # Evaluate on the first, every eval_freq-th, and the final epoch.
        if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
            test_res = utils.eval(loaders['test'], model, criterion)
        else:
            test_res = {'loss': None, 'accuracy': None}

        # SWA: fold weights into the running average every swa_c_epochs
        # once past swa_start.
        if args.swa and (epoch + 1) >= args.swa_start and (
                epoch + 1 - args.swa_start) % args.swa_c_epochs == 0:
            utils.moving_average(swa_model, model, 1.0 / (swa_n + 1))
            swa_n += 1
            if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
                # Averaged weights need fresh BatchNorm statistics.
                utils.bn_update(loaders['train'], swa_model)
                swa_res = utils.eval(loaders['test'], swa_model, criterion)
            else:
                swa_res = {'loss': None, 'accuracy': None}

        if (epoch + 1) % args.save_freq == 0:
            utils.save_checkpoint(
                args.dir,
                epoch + 1,
                state_dict=model.state_dict(),
                swa_state_dict=swa_model.state_dict() if args.swa else None,
                swa_n=swa_n if args.swa else None,
                optimizer=optimizer.state_dict())

        time_ep = time.time() - time_ep
        values = [
            epoch + 1, lr, train_res['loss'], train_res['accuracy'],
            test_res['loss'], test_res['accuracy'], time_ep
        ]
        if args.swa:
            values = values[:-1] + [swa_res['loss'], swa_res['accuracy']
                                    ] + values[-1:]
        table = tabulate.tabulate([values],
                                  columns,
                                  tablefmt='simple',
                                  floatfmt='8.4f')
        # Re-print the table header every 40 epochs.
        if epoch % 40 == 0:
            table = table.split('\n')
            table = '\n'.join([table[1]] + table)
        else:
            table = table.split('\n')[2]
        print(table)

    # Final checkpoint if the last epoch didn't land on a save boundary.
    if args.epochs % args.save_freq != 0:
        utils.save_checkpoint(
            args.dir,
            args.epochs,
            state_dict=model.state_dict(),
            swa_state_dict=swa_model.state_dict() if args.swa else None,
            swa_n=swa_n if args.swa else None,
            optimizer=optimizer.state_dict())