def main(): """Entry.""" # init distributed global is_root_rank if FLAGS.use_distributed: udist.init_dist() FLAGS.batch_size = udist.get_world_size() * FLAGS.per_gpu_batch_size FLAGS._loader_batch_size = FLAGS.per_gpu_batch_size if FLAGS.bn_calibration: FLAGS._loader_batch_size_calib = FLAGS.bn_calibration_per_gpu_batch_size FLAGS.data_loader_workers = round(FLAGS.data_loader_workers / udist.get_local_size()) is_root_rank = udist.is_master() else: count = torch.cuda.device_count() FLAGS.batch_size = count * FLAGS.per_gpu_batch_size FLAGS._loader_batch_size = FLAGS.batch_size if FLAGS.bn_calibration: FLAGS._loader_batch_size_calib = FLAGS.bn_calibration_per_gpu_batch_size * count is_root_rank = True FLAGS.lr = FLAGS.base_lr * (FLAGS.batch_size / FLAGS.base_total_batch) # NOTE: don't drop last batch, thus must use ceil, otherwise learning rate # will be negative FLAGS._steps_per_epoch = int(np.ceil(NUM_IMAGENET_TRAIN / FLAGS.batch_size)) if is_root_rank: FLAGS.log_dir = '{}/{}'.format(FLAGS.log_dir, time.strftime("%Y%m%d-%H%M%S")) create_exp_dir( FLAGS.log_dir, FLAGS.config_path, blacklist_dirs=[ 'exp', '.git', 'pretrained', 'tmp', 'deprecated', 'bak', ], ) setup_logging(FLAGS.log_dir) for k, v in _ENV_EXPAND.items(): logging.info('Env var expand: {} to {}'.format(k, v)) logging.info(FLAGS) set_random_seed(FLAGS.get('random_seed', 0)) with SummaryWriterManager(): train_val_test()
def main(): """Entry.""" FLAGS.test_only = True mc.setup_distributed() if udist.is_master(): FLAGS.log_dir = '{}/{}'.format(FLAGS.log_dir, time.strftime("%Y%m%d-%H%M%S-eval")) setup_logging(FLAGS.log_dir) for k, v in _ENV_EXPAND.items(): logging.info('Env var expand: {} to {}'.format(k, v)) logging.info(FLAGS) set_random_seed(FLAGS.get('random_seed', 0)) with mc.SummaryWriterManager(): val()
def main(): """Entry.""" NUM_IMAGENET_TRAIN = 1281167 if FLAGS.dataset == 'cityscapes': NUM_IMAGENET_TRAIN = 2975 elif FLAGS.dataset == 'ade20k': NUM_IMAGENET_TRAIN = 20210 elif FLAGS.dataset == 'coco': NUM_IMAGENET_TRAIN = 149813 mc.setup_distributed(NUM_IMAGENET_TRAIN) if FLAGS.net_params and FLAGS.model_kwparams.task == 'segmentation': tag, input_channels, block1, block2, block3, block4, last_channel = FLAGS.net_params.split( '-') input_channels = [int(item) for item in input_channels.split('_')] block1 = [int(item) for item in block1.split('_')] block2 = [int(item) for item in block2.split('_')] block3 = [int(item) for item in block3.split('_')] block4 = [int(item) for item in block4.split('_')] last_channel = int(last_channel) inverted_residual_setting = [] for item in [block1, block2, block3, block4]: for _ in range(item[0]): inverted_residual_setting.append([ item[1], item[2:-int(len(item) / 2 - 1)], item[-int(len(item) / 2 - 1):] ]) FLAGS.model_kwparams.input_channel = input_channels FLAGS.model_kwparams.inverted_residual_setting = inverted_residual_setting FLAGS.model_kwparams.last_channel = last_channel if udist.is_master(): FLAGS.log_dir = '{}/{}'.format(FLAGS.log_dir, time.strftime("%Y%m%d-%H%M%S")) # yapf: disable create_exp_dir(FLAGS.log_dir, FLAGS.config_path, blacklist_dirs=[ 'exp', '.git', 'pretrained', 'tmp', 'deprecated', 'bak', 'output']) # yapf: enable setup_logging(FLAGS.log_dir) for k, v in _ENV_EXPAND.items(): logging.info('Env var expand: {} to {}'.format(k, v)) logging.info(FLAGS) set_random_seed(FLAGS.get('random_seed', 0)) with mc.SummaryWriterManager(): train_val_test()
def main(): """Entry.""" NUM_IMAGENET_TRAIN = 1281167 mc.setup_distributed(NUM_IMAGENET_TRAIN) if udist.is_master(): FLAGS.log_dir = '{}/{}'.format(FLAGS.log_dir, time.strftime("%Y%m%d-%H%M%S")) # yapf: disable create_exp_dir(FLAGS.log_dir, FLAGS.config_path, blacklist_dirs=[ 'exp', '.git', 'pretrained', 'tmp', 'deprecated', 'bak']) # yapf: enable setup_logging(FLAGS.log_dir) for k, v in _ENV_EXPAND.items(): logging.info('Env var expand: {} to {}'.format(k, v)) logging.info(FLAGS) set_random_seed(FLAGS.get('random_seed', 0)) with mc.SummaryWriterManager(): train_val_test()
def get_model():
    """Build and init model with wrapper for parallel."""
    model_lib = importlib.import_module(FLAGS.model)
    model = model_lib.Model(**FLAGS.model_kwparams, input_size=FLAGS.image_size)
    if FLAGS.reset_parameters:
        init_method = FLAGS.get('reset_param_method', None)
        if init_method is None:
            pass  # fall back to model's initialization
        elif init_method == 'slimmable':
            model.apply(mb.init_weights_slimmable)
        elif init_method == 'mnas':
            model.apply(mb.init_weights_mnas)
        else:
            raise ValueError('Unknown init method: {}'.format(init_method))
        logging.info('Init model by: {}'.format(init_method))
    if FLAGS.use_distributed:
        model_wrapper = udist.AllReduceDistributedDataParallel(model.cuda())
    else:
        model_wrapper = torch.nn.DataParallel(model).cuda()
    return model, model_wrapper
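# Illustration (not the repo's `mb.init_weights_slimmable` / `mb.init_weights_mnas`):
# the general `model.apply(fn)` initialization pattern that `get_model` relies on.
# The exact schemes used by those helpers may differ from this sketch.
def _example_init_weights(m):
    import torch.nn as nn
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, std=0.01)
        nn.init.zeros_(m.bias)
# Usage: model.apply(_example_init_weights)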
def validate(epoch, calib_loader, val_loader, criterion, val_meters,
             model_wrapper, ema, phase):
    """Calibrate and validate."""
    assert phase in ['test', 'val']
    model_eval_wrapper = get_ema_model(ema, model_wrapper)

    # bn_calibration
    if FLAGS.get('bn_calibration', False):
        if not FLAGS.use_distributed:
            logging.warning(
                'Only GPU0 is used for calibration when using DataParallel')
        with torch.no_grad():
            _ = run_one_epoch(epoch, calib_loader, model_eval_wrapper,
                              criterion, None, None, None, None, val_meters,
                              max_iter=FLAGS.bn_calibration_steps,
                              phase='bn_calibration')
        if FLAGS.use_distributed:
            udist.allreduce_bn(model_eval_wrapper)

    # val
    with torch.no_grad():
        results = run_one_epoch(epoch, val_loader, model_eval_wrapper,
                                criterion, None, None, None, None, val_meters,
                                phase=phase)
    summary_bn(model_eval_wrapper, phase)
    return results, model_eval_wrapper
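# Illustration (not the repo's implementation): what the 'bn_calibration' phase
# amounts to conceptually. BN running statistics are reset and re-estimated from
# a few calibration batches; in the code above, `run_one_epoch` handles this
# through its own phase logic, so this sketch is only for orientation.
def _example_bn_calibration(model, calib_loader, max_iter):
    import torch
    for m in model.modules():
        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
            m.reset_running_stats()   # forget stats from the previous configuration
            m.momentum = None         # use a cumulative moving average instead
    model.train()                     # BN updates running stats only in train mode
    with torch.no_grad():
        for i, (images, _) in enumerate(calib_loader):
            if i >= max_iter:
                break
            model(images.cuda())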
def train_val_test():
    """Train and val."""
    torch.backends.cudnn.benchmark = True

    # model
    model, model_wrapper = mc.get_model()
    ema = mc.setup_ema(model)
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    criterion_smooth = optim.CrossEntropyLabelSmooth(
        FLAGS.model_kwparams['num_classes'],
        FLAGS['label_smoothing'],
        reduction='none').cuda()
    # TODO(meijieru): calculate the loss on all GPUs instead of only `cuda:0`
    # when not distributed.

    if FLAGS.get('log_graph_only', False):
        if udist.is_master():
            _input = torch.zeros(1, 3, FLAGS.image_size, FLAGS.image_size).cuda()
            _input = _input.requires_grad_(True)
            mc.summary_writer.add_graph(model_wrapper, (_input, ), verbose=True)
        return

    # check pretrained
    if FLAGS.pretrained:
        checkpoint = torch.load(FLAGS.pretrained,
                                map_location=lambda storage, loc: storage)
        if ema:
            ema.load_state_dict(checkpoint['ema'])
            ema.to(get_device(model))
        # update keys from external models
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if (hasattr(FLAGS, 'pretrained_model_remap_keys')
                and FLAGS.pretrained_model_remap_keys):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                logging.info('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_wrapper.load_state_dict(checkpoint)
        logging.info('Loaded model {}.'.format(FLAGS.pretrained))
    optimizer = optim.get_optimizer(model_wrapper, FLAGS)

    # check resume training
    if FLAGS.resume:
        checkpoint = torch.load(os.path.join(FLAGS.resume, 'latest_checkpoint.pt'),
                                map_location=lambda storage, loc: storage)
        model_wrapper.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if ema:
            ema.load_state_dict(checkpoint['ema'])
            ema.to(get_device(model))
        last_epoch = checkpoint['last_epoch']
        lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS)
        lr_scheduler.last_epoch = (last_epoch + 1) * FLAGS._steps_per_epoch
        best_val = extract_item(checkpoint['best_val'])
        train_meters, val_meters = checkpoint['meters']
        FLAGS._global_step = (last_epoch + 1) * FLAGS._steps_per_epoch
        if udist.is_master():
            logging.info('Loaded checkpoint {} at epoch {}.'.format(
                FLAGS.resume, last_epoch))
    else:
        lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS)
        # last_epoch = lr_scheduler.last_epoch
        last_epoch = -1
        best_val = 1.
        train_meters = mc.get_meters('train')
        val_meters = mc.get_meters('val')
        FLAGS._global_step = 0

    if not FLAGS.resume and udist.is_master():
        logging.info(model_wrapper)

    if FLAGS.profiling:
        if 'gpu' in FLAGS.profiling:
            mc.profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            mc.profiling(model, use_cuda=False)

    # data
    (train_transforms, val_transforms,
     test_transforms) = dataflow.data_transforms(FLAGS)
    (train_set, val_set, test_set) = dataflow.dataset(train_transforms,
                                                      val_transforms,
                                                      test_transforms, FLAGS)
    (train_loader, calib_loader, val_loader,
     test_loader) = dataflow.data_loader(train_set, val_set, test_set, FLAGS)

    if FLAGS.test_only and (test_loader is not None):
        if udist.is_master():
            logging.info('Start testing.')
        test_meters = mc.get_meters('test')
        validate(last_epoch, calib_loader, test_loader, criterion, test_meters,
                 model_wrapper, ema, 'test')
        return

    # already broadcast by AllReduceDistributedDataParallel
    # optimizer load same checkpoint/same initialization

    if udist.is_master():
        logging.info('Start training.')
    for epoch in range(last_epoch + 1, FLAGS.num_epochs):
        # train
        results = run_one_epoch(epoch, train_loader, model_wrapper,
                                criterion_smooth, optimizer, lr_scheduler, ema,
                                train_meters, phase='train')

        # val
        results = validate(epoch, calib_loader, val_loader, criterion,
                           val_meters, model_wrapper, ema, 'val')
        if results['top1_error'] < best_val:
            best_val = results['top1_error']
            if udist.is_master():
                save_status(model_wrapper, optimizer, ema, epoch, best_val,
                            (train_meters, val_meters),
                            os.path.join(FLAGS.log_dir, 'best_model.pt'))
                logging.info(
                    'New best validation top1 error: {:.4f}'.format(best_val))

        if udist.is_master():
            # save latest checkpoint
            save_status(model_wrapper, optimizer, ema, epoch, best_val,
                        (train_meters, val_meters),
                        os.path.join(FLAGS.log_dir, 'latest_checkpoint.pt'))
            wandb.log(
                {
                    "Validation Accuracy": 1. - results['top1_error'],
                    "Best Validation Accuracy": 1. - best_val
                },
                step=epoch)

        # NOTE(meijieru): per the scheduler docs this should be called after
        # train/val; a stepwise scheduler is used instead.
        # lr_scheduler.step()
    return
def validate(epoch, calib_loader, val_loader, criterion, val_meters,
             model_wrapper, ema, phase, segval=None, val_set=None):
    """Calibrate and validate."""
    assert phase in ['test', 'val']
    model_eval_wrapper = mc.get_ema_model(ema, model_wrapper)

    # bn_calibration
    if FLAGS.prune_params['method'] is not None:
        if FLAGS.get('bn_calibration', False):
            if not FLAGS.use_distributed:
                logging.warning(
                    'Only GPU0 is used for calibration when using DataParallel')
            with torch.no_grad():
                _ = run_one_epoch(epoch, calib_loader, model_eval_wrapper,
                                  criterion, None, None, None, None, val_meters,
                                  max_iter=FLAGS.bn_calibration_steps,
                                  phase='bn_calibration')
            if FLAGS.use_distributed:
                udist.allreduce_bn(model_eval_wrapper)

    # val
    with torch.no_grad():
        if FLAGS.model_kwparams.task == 'segmentation':
            if FLAGS.dataset == 'coco':
                results = 0
                if udist.is_master():
                    results = keypoint_val(val_set, val_loader,
                                           model_eval_wrapper.module, criterion)
            else:
                assert segval is not None
                results = segval.run(
                    epoch, val_loader, model_eval_wrapper.module
                    if FLAGS.single_gpu_test else model_eval_wrapper, FLAGS)
        else:
            results = run_one_epoch(epoch, val_loader, model_eval_wrapper,
                                    criterion, None, None, None, None,
                                    val_meters, phase=phase)
    summary_bn(model_eval_wrapper, phase)
    return results, model_eval_wrapper
def train_val_test():
    """Train and val."""
    torch.backends.cudnn.benchmark = True  # for acceleration

    # model
    model, model_wrapper = mc.get_model()
    ema = mc.setup_ema(model)
    criterion = torch.nn.CrossEntropyLoss(reduction='mean').cuda()
    criterion_smooth = optim.CrossEntropyLabelSmooth(
        FLAGS.model_kwparams['num_classes'],
        FLAGS['label_smoothing'],
        reduction='mean').cuda()
    if model.task == 'segmentation':
        criterion = CrossEntropyLoss().cuda()
        criterion_smooth = CrossEntropyLoss().cuda()
    if FLAGS.dataset == 'coco':
        criterion = JointsMSELoss(use_target_weight=True).cuda()
        criterion_smooth = JointsMSELoss(use_target_weight=True).cuda()

    if FLAGS.get('log_graph_only', False):
        if udist.is_master():
            _input = torch.zeros(1, 3, FLAGS.image_size, FLAGS.image_size).cuda()
            _input = _input.requires_grad_(True)
            if isinstance(model_wrapper,
                          (torch.nn.DataParallel,
                           udist.AllReduceDistributedDataParallel)):
                mc.summary_writer.add_graph(model_wrapper.module, (_input, ),
                                            verbose=True)
            else:
                mc.summary_writer.add_graph(model_wrapper, (_input, ),
                                            verbose=True)
        return

    # check pretrained
    if FLAGS.pretrained:
        checkpoint = torch.load(FLAGS.pretrained,
                                map_location=lambda storage, loc: storage)
        if ema:
            ema.load_state_dict(checkpoint['ema'])
            ema.to(get_device(model))
        # update keys from external models
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if (hasattr(FLAGS, 'pretrained_model_remap_keys')
                and FLAGS.pretrained_model_remap_keys):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                if udist.is_master():
                    logging.info('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_wrapper.load_state_dict(checkpoint)
        if udist.is_master():
            logging.info('Loaded model {}.'.format(FLAGS.pretrained))
    optimizer = optim.get_optimizer(model_wrapper, FLAGS)

    # check resume training
    if FLAGS.resume:
        checkpoint = torch.load(os.path.join(FLAGS.resume, 'latest_checkpoint.pt'),
                                map_location=lambda storage, loc: storage)
        model_wrapper = checkpoint['model'].cuda()
        model = model_wrapper.module
        # model = checkpoint['model'].module
        optimizer = checkpoint['optimizer']
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        # model_wrapper.load_state_dict(checkpoint['model'])
        # optimizer.load_state_dict(checkpoint['optimizer'])
        if ema:
            # ema.load_state_dict(checkpoint['ema'])
            ema = checkpoint['ema'].cuda()
            ema.to(get_device(model))
        last_epoch = checkpoint['last_epoch']
        lr_scheduler = optim.get_lr_scheduler(
            optimizer, FLAGS,
            last_epoch=(last_epoch + 1) * FLAGS._steps_per_epoch)
        lr_scheduler.last_epoch = (last_epoch + 1) * FLAGS._steps_per_epoch
        best_val = extract_item(checkpoint['best_val'])
        train_meters, val_meters = checkpoint['meters']
        FLAGS._global_step = (last_epoch + 1) * FLAGS._steps_per_epoch
        if udist.is_master():
            logging.info('Loaded checkpoint {} at epoch {}.'.format(
                FLAGS.resume, last_epoch))
    else:
        lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS)
        # last_epoch = lr_scheduler.last_epoch
        last_epoch = -1
        best_val = 1.
        if not FLAGS.distill:
            train_meters = mc.get_meters('train', FLAGS.prune_params['method'])
            val_meters = mc.get_meters('val')
        else:
            train_meters = mc.get_distill_meters('train',
                                                 FLAGS.prune_params['method'])
            val_meters = mc.get_distill_meters('val')
        if FLAGS.model_kwparams.task == 'segmentation':
            best_val = 0.
            if not FLAGS.distill:
                train_meters = mc.get_seg_meters('train',
                                                 FLAGS.prune_params['method'])
                val_meters = mc.get_seg_meters('val')
            else:
                train_meters = mc.get_seg_distill_meters(
                    'train', FLAGS.prune_params['method'])
                val_meters = mc.get_seg_distill_meters('val')
        FLAGS._global_step = 0

    if not FLAGS.resume and udist.is_master():
        logging.info(model_wrapper)
    assert FLAGS.profiling, '`m.macs` is used for calculating penalty'
    # if udist.is_master():
    #     model.apply(lambda m: print(m))
    if FLAGS.profiling:
        if 'gpu' in FLAGS.profiling:
            mc.profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            mc.profiling(model, use_cuda=False)

    if FLAGS.dataset == 'cityscapes':
        (train_set, val_set, test_set) = seg_dataflow.cityscapes_datasets(FLAGS)
        segval = SegVal(num_classes=19)
    elif FLAGS.dataset == 'ade20k':
        (train_set, val_set, test_set) = seg_dataflow.ade20k_datasets(FLAGS)
        segval = SegVal(num_classes=150)
    elif FLAGS.dataset == 'coco':
        (train_set, val_set, test_set) = seg_dataflow.coco_datasets(FLAGS)
        # print(len(train_set), len(val_set))  # 149813 104125
        segval = None
    else:
        # data
        (train_transforms, val_transforms,
         test_transforms) = dataflow.data_transforms(FLAGS)
        (train_set, val_set, test_set) = dataflow.dataset(train_transforms,
                                                          val_transforms,
                                                          test_transforms, FLAGS)
        segval = None

    (train_loader, calib_loader, val_loader,
     test_loader) = dataflow.data_loader(train_set, val_set, test_set, FLAGS)

    # get bn's weights
    if FLAGS.prune_params.use_transformer:
        FLAGS._bn_to_prune, FLAGS._bn_to_prune_transformer = prune.get_bn_to_prune(
            model, FLAGS.prune_params)
    else:
        FLAGS._bn_to_prune = prune.get_bn_to_prune(model, FLAGS.prune_params)
    rho_scheduler = prune.get_rho_scheduler(FLAGS.prune_params,
                                            FLAGS._steps_per_epoch)

    if FLAGS.test_only and (test_loader is not None):
        if udist.is_master():
            logging.info('Start testing.')
        test_meters = mc.get_meters('test')
        validate(last_epoch, calib_loader, test_loader, criterion, test_meters,
                 model_wrapper, ema, 'test')
        return

    # already broadcast by AllReduceDistributedDataParallel
    # optimizer load same checkpoint/same initialization

    if udist.is_master():
        logging.info('Start training.')
    for epoch in range(last_epoch + 1, FLAGS.num_epochs):
        # train
        results = run_one_epoch(epoch, train_loader, model_wrapper,
                                criterion_smooth, optimizer, lr_scheduler, ema,
                                rho_scheduler, train_meters, phase='train')

        if (epoch + 1) % FLAGS.eval_interval == 0:
            # val
            results, model_eval_wrapper = validate(epoch, calib_loader,
                                                   val_loader, criterion,
                                                   val_meters, model_wrapper,
                                                   ema, 'val', segval, val_set)

            if FLAGS.prune_params['method'] is not None and FLAGS.prune_params[
                    'bn_prune_filter'] is not None:
                prune_threshold = FLAGS.model_shrink_threshold  # 1e-3
                masks = prune.cal_mask_network_slimming_by_threshold(
                    get_prune_weights(model_eval_wrapper), prune_threshold
                )  # get mask for all bn weights (depth-wise)
                FLAGS._bn_to_prune.add_info_list('mask', masks)
                flops_pruned, infos = prune.cal_pruned_flops(FLAGS._bn_to_prune)
                log_pruned_info(mc.unwrap_model(model_eval_wrapper),
                                flops_pruned, infos, prune_threshold)
                if not FLAGS.distill:
                    if flops_pruned >= FLAGS.model_shrink_delta_flops \
                            or epoch == FLAGS.num_epochs - 1:
                        ema_only = (epoch == FLAGS.num_epochs - 1)
                        shrink_model(model_wrapper, ema, optimizer,
                                     FLAGS._bn_to_prune, prune_threshold,
                                     ema_only)
            model_kwparams = mb.output_network(mc.unwrap_model(model_wrapper))

            if udist.is_master():
                if FLAGS.model_kwparams.task == 'classification' and results[
                        'top1_error'] < best_val:
                    best_val = results['top1_error']
                    logging.info(
                        'New best validation top1 error: {:.4f}'.format(best_val))
                    save_status(model_wrapper, model_kwparams, optimizer, ema,
                                epoch, best_val, (train_meters, val_meters),
                                os.path.join(FLAGS.log_dir, 'best_model'))
                elif FLAGS.model_kwparams.task == 'segmentation' and \
                        FLAGS.dataset != 'coco' and results['mIoU'] > best_val:
                    best_val = results['mIoU']
                    logging.info('New seg mIoU: {:.4f}'.format(best_val))
                    save_status(model_wrapper, model_kwparams, optimizer, ema,
                                epoch, best_val, (train_meters, val_meters),
                                os.path.join(FLAGS.log_dir, 'best_model'))
                elif FLAGS.dataset == 'coco' and results > best_val:
                    best_val = results
                    logging.info('New Result: {:.4f}'.format(best_val))
                    save_status(model_wrapper, model_kwparams, optimizer, ema,
                                epoch, best_val, (train_meters, val_meters),
                                os.path.join(FLAGS.log_dir, 'best_model'))

                # save latest checkpoint
                save_status(model_wrapper, model_kwparams, optimizer, ema,
                            epoch, best_val, (train_meters, val_meters),
                            os.path.join(FLAGS.log_dir, 'latest_checkpoint'))

    return
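# Illustration (not the repo's `prune.cal_mask_network_slimming_by_threshold`):
# the network-slimming idea used above, where channels whose BN scale |gamma|
# falls below a threshold are marked for pruning. The function name and the
# mask format here are made up for the example.
def _example_bn_masks_by_threshold(model, threshold=1e-3):
    import torch
    masks = []
    for m in model.modules():
        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
            # True = keep channel, False = candidate for pruning
            masks.append(m.weight.detach().abs() > threshold)
    return masks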
def train_val_test():
    """Train and val."""
    torch.backends.cudnn.benchmark = True

    # model
    model, model_wrapper = get_model()
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    criterion_smooth = optim.CrossEntropyLabelSmooth(
        FLAGS.model_kwparams['num_classes'],
        FLAGS['label_smoothing'],
        reduction='none').cuda()
    # TODO: calculate the loss on all GPUs instead of only `cuda:0` when not
    # distributed.

    ema = None
    if FLAGS.moving_average_decay > 0.0:
        if FLAGS.moving_average_decay_adjust:
            moving_average_decay = optim.ExponentialMovingAverage.adjust_momentum(
                FLAGS.moving_average_decay,
                FLAGS.moving_average_decay_base_batch / FLAGS.batch_size)
        else:
            moving_average_decay = FLAGS.moving_average_decay
        logging.info('Moving average for model parameters: {}'.format(
            moving_average_decay))
        ema = optim.ExponentialMovingAverage(moving_average_decay)
        for name, param in model.named_parameters():
            ema.register(name, param)
        # We maintain the moving average for batch norm running mean and
        # variance as well.
        for name, buffer in model.named_buffers():
            if 'running_var' in name or 'running_mean' in name:
                ema.register(name, buffer)

    if FLAGS.get('log_graph_only', False):
        if is_root_rank:
            _input = torch.zeros(1, 3, FLAGS.image_size, FLAGS.image_size).cuda()
            _input = _input.requires_grad_(True)
            summary_writer.add_graph(model_wrapper, (_input, ), verbose=True)
        return

    # check pretrained
    if FLAGS.pretrained:
        checkpoint = torch.load(FLAGS.pretrained,
                                map_location=lambda storage, loc: storage)
        if ema:
            ema.load_state_dict(checkpoint['ema'])
            ema.to(get_device(model))
        # update keys from external models
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if (hasattr(FLAGS, 'pretrained_model_remap_keys')
                and FLAGS.pretrained_model_remap_keys):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                logging.info('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_wrapper.load_state_dict(checkpoint)
        logging.info('Loaded model {}.'.format(FLAGS.pretrained))
    optimizer = optim.get_optimizer(model_wrapper, FLAGS)

    # check resume training
    if FLAGS.resume:
        checkpoint = torch.load(os.path.join(FLAGS.resume, 'latest_checkpoint.pt'),
                                map_location=lambda storage, loc: storage)
        model_wrapper.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if ema:
            ema.load_state_dict(checkpoint['ema'])
            ema.to(get_device(model))
        last_epoch = checkpoint['last_epoch']
        lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS)
        lr_scheduler.last_epoch = (last_epoch + 1) * FLAGS._steps_per_epoch
        best_val = extract_item(checkpoint['best_val'])
        train_meters, val_meters = checkpoint['meters']
        FLAGS._global_step = (last_epoch + 1) * FLAGS._steps_per_epoch
        if is_root_rank:
            logging.info('Loaded checkpoint {} at epoch {}.'.format(
                FLAGS.resume, last_epoch))
    else:
        lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS)
        # last_epoch = lr_scheduler.last_epoch
        last_epoch = -1
        best_val = 1.
        train_meters = get_meters('train')
        val_meters = get_meters('val')
        FLAGS._global_step = 0

    if not FLAGS.resume and is_root_rank:
        logging.info(model_wrapper)

    assert FLAGS.profiling, '`m.macs` is used for calculating penalty'
    if FLAGS.profiling:
        if 'gpu' in FLAGS.profiling:
            profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            profiling(model, use_cuda=False)

    # data
    (train_transforms, val_transforms,
     test_transforms) = dataflow.data_transforms(FLAGS)
    (train_set, val_set, test_set) = dataflow.dataset(train_transforms,
                                                      val_transforms,
                                                      test_transforms, FLAGS)
    (train_loader, calib_loader, val_loader,
     test_loader) = dataflow.data_loader(train_set, val_set, test_set, FLAGS)

    # get bn's weights
    FLAGS._bn_to_prune = prune.get_bn_to_prune(model, FLAGS.prune_params)
    rho_scheduler = prune.get_rho_scheduler(FLAGS.prune_params,
                                            FLAGS._steps_per_epoch)

    if FLAGS.test_only and (test_loader is not None):
        if is_root_rank:
            logging.info('Start testing.')
        test_meters = get_meters('test')
        validate(last_epoch, calib_loader, test_loader, criterion, test_meters,
                 model_wrapper, ema, 'test')
        return

    # already broadcast by AllReduceDistributedDataParallel
    # optimizer load same checkpoint/same initialization

    if is_root_rank:
        logging.info('Start training.')
    for epoch in range(last_epoch + 1, FLAGS.num_epochs):
        # train
        results = run_one_epoch(epoch, train_loader, model_wrapper,
                                criterion_smooth, optimizer, lr_scheduler, ema,
                                rho_scheduler, train_meters, phase='train')

        # val
        results, model_eval_wrapper = validate(epoch, calib_loader, val_loader,
                                               criterion, val_meters,
                                               model_wrapper, ema, 'val')

        if FLAGS.prune_params['method'] is not None:
            prune_threshold = FLAGS.model_shrink_threshold
            masks = prune.cal_mask_network_slimming_by_threshold(
                get_prune_weights(model_eval_wrapper), prune_threshold)
            FLAGS._bn_to_prune.add_info_list('mask', masks)
            flops_pruned, infos = prune.cal_pruned_flops(FLAGS._bn_to_prune)
            log_pruned_info(unwrap_model(model_eval_wrapper), flops_pruned,
                            infos, prune_threshold)
            if flops_pruned >= FLAGS.model_shrink_delta_flops \
                    or epoch == FLAGS.num_epochs - 1:
                ema_only = (epoch == FLAGS.num_epochs - 1)
                shrink_model(model_wrapper, ema, optimizer, FLAGS._bn_to_prune,
                             prune_threshold, ema_only)
        model_kwparams = mb.output_network(unwrap_model(model_wrapper))

        if results['top1_error'] < best_val:
            best_val = results['top1_error']
            if is_root_rank:
                save_status(model_wrapper, model_kwparams, optimizer, ema,
                            epoch, best_val, (train_meters, val_meters),
                            os.path.join(FLAGS.log_dir, 'best_model'))
                logging.info(
                    'New best validation top1 error: {:.4f}'.format(best_val))

        if is_root_rank:
            # save latest checkpoint
            save_status(model_wrapper, model_kwparams, optimizer, ema, epoch,
                        best_val, (train_meters, val_meters),
                        os.path.join(FLAGS.log_dir, 'latest_checkpoint'))

        # NOTE: per the scheduler docs this should be called after train/val;
        # a stepwise scheduler is used instead.
        # lr_scheduler.step()
    return
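# Illustration (not the repo's `optim.ExponentialMovingAverage`): the
# shadow-parameter EMA scheme that the registration loop above sets up. The
# class name and API here are simplified for the example; the real class also
# handles state_dict saving/loading and device moves.
class _ExampleEMA:
    def __init__(self, decay):
        self.decay = decay
        self.shadow = {}

    def register(self, name, tensor):
        # keep a detached copy that will track the parameter over training
        self.shadow[name] = tensor.detach().clone()

    def update(self, name, tensor):
        # shadow <- decay * shadow + (1 - decay) * current value
        self.shadow[name].mul_(self.decay).add_(tensor.detach(),
                                                alpha=1.0 - self.decay)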