def testAdjustEmaRate(self):
    for num_repeat in [1, 5, 9]:
        for momentum in [0.25, 0.9999]:
            name = 'v'
            values = torch.randn(5)
            # Repeat each value `num_repeat` times consecutively.
            values_long = values.repeat(num_repeat, 1).permute(
                (1, 0)).contiguous().view(-1)
            ema = optim.ExponentialMovingAverage(momentum)
            ema.register(name, values[0])
            for v in values:
                ema(name, v)
            lhs = ema.average(name)
            # The adjusted momentum applied over the repeated stream
            # should reproduce the same average.
            adjusted = optim.ExponentialMovingAverage.adjust_momentum(
                momentum, num_repeat)
            ema = optim.ExponentialMovingAverage(adjusted)
            ema.register(name, values[0])
            for v in values_long:
                ema(name, v)
            rhs = ema.average(name)
            assertAllClose(lhs, rhs)

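# The equivalence checked above rests on the per-step update
#     shadow = momentum * shadow + (1 - momentum) * value,
# under which applying momentum m' to the same value num_repeat times
# collapses to a single update with momentum m' ** num_repeat. A sketch of
# what `adjust_momentum` is therefore expected to compute (an assumption
# inferred from the test, not the library source):
#
#     def adjust_momentum(momentum, num_repeat):
#         # num_repeat-th root, so that result ** num_repeat == momentum.
#         return momentum ** (1.0 / num_repeat)
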
def testThrowValueError(self):
    ema = optim.ExponentialMovingAverage(0.25)
    # Registering a non-floating-point tensor should be rejected.
    with self.assertRaises(ValueError):
        ema.register('test', torch.tensor(5, dtype=torch.int))

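# Presumably the register-time dtype check guards the in-place update of
# the shadow tensor, which is only well-defined for floating-point dtypes;
# this is an assumption inferred from the test, not documented behavior.
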
def testDevice(self):
    name = 'zeros'
    ema = optim.ExponentialMovingAverage(0.25)
    ema.register(name, torch.zeros(5))
    for device in ['cpu', 'cuda:0']:
        device = torch.device(device)
        ema = ema.to(device)
        cur_device = ema._shadow[name].device
        assert cur_device == device

def testAverageVariablesUpdateNumUpdates_Vector(self):
    ema = optim.ExponentialMovingAverage(0.25)
    name = 'tens'
    tens = _Repeat(10.0, dim=5)
    var = torch.tensor(tens)
    ema.register(name, var, zero_init=False)
    for num_updates in range(2):
        var.add_(1)
        ema(name, var, num_updates=num_updates)
    expected = _Repeat(
        (10 * 0.1 + 11 * 0.9) * 2.0 / 11.0 + 12 * 9.0 / 11.0, dim=5)
    assertAllClose(expected, ema.average(name))

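# The expected value is consistent with a TF-style clamp on the effective
# decay, decay = min(momentum, (1 + num_updates) / (10 + num_updates)),
# applied as avg = decay * avg + (1 - decay) * value (an assumption
# inferred from the numbers, not the library source):
#
#     step 0: decay = min(0.25, 1 / 10) = 0.1
#             avg   = 0.1 * 10 + 0.9 * 11
#     step 1: decay = min(0.25, 2 / 11)
#             avg   = (0.1 * 10 + 0.9 * 11) * 2 / 11 + 12 * 9 / 11
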
def testSaveLoad(self):
    ema = optim.ExponentialMovingAverage(0.25)
    name = 'tens'
    tens = _Repeat(10.0, dim=5)
    var = torch.tensor(tens)
    ema.register(name, var, zero_init=False)
    state_dict = ema.state_dict()
    for key in ['info', 'shadow', 'param']:
        assert key in state_dict
    assert 'tens' in state_dict['shadow']
    assertAllClose(state_dict['shadow']['tens'], var)
    ema.load_state_dict(state_dict)
    # Loading with a mismatched hyper-parameter should warn.
    state_dict['param']['momentum'] = 0.5
    self.assertWarns(RuntimeWarning,
                     lambda: ema.load_state_dict(state_dict))

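# Round-trip sketch with standard torch serialization (the file name is
# illustrative):
#
#     torch.save(ema.state_dict(), 'ema.pt')
#     ema.load_state_dict(torch.load('ema.pt'))
#
# As the final assertion checks, loading a state dict whose stored
# hyper-parameters (here `momentum`) disagree with the live object's only
# raises a RuntimeWarning instead of failing.
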
def testCompress(self):
    ema = optim.ExponentialMovingAverage(0.25)
    ema.register('var_prune', torch.arange(5).float())
    ema.register('var_keep', torch.arange(5, 10).float())
    ema('var_prune', torch.arange(5).float())
    info = {
        'var_old_name': 'var_prune',
        'var_new_name': 'var_new',
        'var_new': torch.randn(3),
        'mask': torch.tensor([False, True, False, True, True]),
        'mask_hook': lambda lhs, rhs, mask: lhs.data.copy_(rhs.data[mask])
    }
    ema.compress_mask(info, verbose=False)
    self.assertTrue(info['var_new_name'] in ema._shadow)
    self.assertTrue(info['var_new_name'] in ema._info)
    self.assertTrue(info['var_old_name'] not in ema._shadow)
    self.assertTrue(info['var_old_name'] not in ema._info)
    self.assertEqual(ema._info[info['var_new_name']]['num_updates'], 1)
    assertAllClose(ema.average(info['var_new_name']), [1, 3, 4])

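# Arithmetic behind the final assertion: registering 'var_prune' with
# [0, 1, 2, 3, 4] and updating once with the same values leaves the shadow
# at [0, 1, 2, 3, 4]; `mask_hook` then keeps the entries selected by
# [False, True, False, True, True], giving [1, 3, 4] under the new name
# with `num_updates` carried over.
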
def setup_ema(model):
    """Setup EMA for model's weights."""
    from utils import optim
    ema = None
    if FLAGS.moving_average_decay > 0.0:
        if FLAGS.moving_average_decay_adjust:
            moving_average_decay = \
                optim.ExponentialMovingAverage.adjust_momentum(
                    FLAGS.moving_average_decay,
                    FLAGS.moving_average_decay_base_batch / FLAGS.batch_size)
        else:
            moving_average_decay = FLAGS.moving_average_decay
        logging.info('Moving average for model parameters: {}'.format(
            moving_average_decay))
        ema = optim.ExponentialMovingAverage(moving_average_decay)
        for name, param in model.named_parameters():
            ema.register(name, param)
        # We maintain EMA for batch norm running mean and variance as well.
        for name, buffer in model.named_buffers():
            if 'running_var' in name or 'running_mean' in name:
                ema.register(name, buffer)
    return ema

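# Usage sketch for `setup_ema` inside a training step (assumes the FLAGS
# fields above are populated; `loader`, `optimizer`, and `criterion` are
# illustrative names):
#
#     ema = setup_ema(model)
#     for step, (x, y) in enumerate(loader):
#         criterion(model(x), y).mean().backward()
#         optimizer.step()
#         optimizer.zero_grad()
#         if ema is not None:
#             for name, param in model.named_parameters():
#                 ema(name, param, num_updates=step)
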
def testAverageVariablesNumUpdates_Vector_Debias(self):
    # With num_updates=1, the decay applied is 0.1818.
    ema = optim.ExponentialMovingAverage(0.25, zero_debias=True)
    self._CheckDecay(ema, actual_decay=0.181818, dim=5, num_updates=1)

def testAverageVariablesNoNumUpdates_Vector_Debias(self):
    ema = optim.ExponentialMovingAverage(0.25, zero_debias=True)
    self._CheckDecay(ema, actual_decay=0.25, dim=5)

def testAverageVariablesNumUpdates_Vector(self):
    ema = optim.ExponentialMovingAverage(0.25)
    self._CheckDecay(ema, actual_decay=0.181818, dim=5, num_updates=1)

def testAverageVariablesNumUpdates_Scalar(self):
    # With num_updates=1, the decay applied is 0.1818.
    ema = optim.ExponentialMovingAverage(0.25)
    self._CheckDecay(ema, actual_decay=0.181818, dim=1, num_updates=1)

def testAverageVariablesNoNumUpdates_Scalar(self):
    ema = optim.ExponentialMovingAverage(0.25)
    self._CheckDecay(ema, actual_decay=0.25, dim=1)

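# In the tests above, actual_decay = 0.181818 is 2 / 11, i.e. the
# num_updates clamp min(0.25, (1 + 1) / (10 + 1)) taking effect, while the
# NoNumUpdates variants see the raw momentum 0.25 (same assumed clamp as
# noted for testAverageVariablesUpdateNumUpdates_Vector).
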
def train_val_test():
    """Train and val."""
    torch.backends.cudnn.benchmark = True

    # model
    model, model_wrapper = get_model()
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    criterion_smooth = optim.CrossEntropyLabelSmooth(
        FLAGS.model_kwparams['num_classes'],
        FLAGS['label_smoothing'],
        reduction='none').cuda()
    # TODO: calculate loss on all GPUs instead of only `cuda:0` when
    # non-distributed

    ema = None
    if FLAGS.moving_average_decay > 0.0:
        if FLAGS.moving_average_decay_adjust:
            moving_average_decay = \
                optim.ExponentialMovingAverage.adjust_momentum(
                    FLAGS.moving_average_decay,
                    FLAGS.moving_average_decay_base_batch / FLAGS.batch_size)
        else:
            moving_average_decay = FLAGS.moving_average_decay
        logging.info('Moving average for model parameters: {}'.format(
            moving_average_decay))
        ema = optim.ExponentialMovingAverage(moving_average_decay)
        for name, param in model.named_parameters():
            ema.register(name, param)
        # We maintain EMA for batch norm running mean and variance as well.
        for name, buffer in model.named_buffers():
            if 'running_var' in name or 'running_mean' in name:
                ema.register(name, buffer)

    if FLAGS.get('log_graph_only', False):
        if is_root_rank:
            _input = torch.zeros(1, 3, FLAGS.image_size,
                                 FLAGS.image_size).cuda()
            _input = _input.requires_grad_(True)
            summary_writer.add_graph(model_wrapper, (_input, ), verbose=True)
        return

    # check pretrained
    if FLAGS.pretrained:
        checkpoint = torch.load(FLAGS.pretrained,
                                map_location=lambda storage, loc: storage)
        if ema:
            ema.load_state_dict(checkpoint['ema'])
            ema.to(get_device(model))
        # update keys from external models
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if (hasattr(FLAGS, 'pretrained_model_remap_keys')
                and FLAGS.pretrained_model_remap_keys):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                logging.info('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_wrapper.load_state_dict(checkpoint)
        logging.info('Loaded model {}.'.format(FLAGS.pretrained))
    optimizer = optim.get_optimizer(model_wrapper, FLAGS)

    # check resume training
    if FLAGS.resume:
        checkpoint = torch.load(
            os.path.join(FLAGS.resume, 'latest_checkpoint.pt'),
            map_location=lambda storage, loc: storage)
        model_wrapper.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if ema:
            ema.load_state_dict(checkpoint['ema'])
            ema.to(get_device(model))
        last_epoch = checkpoint['last_epoch']
        lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS)
        lr_scheduler.last_epoch = (last_epoch + 1) * FLAGS._steps_per_epoch
        best_val = extract_item(checkpoint['best_val'])
        train_meters, val_meters = checkpoint['meters']
        FLAGS._global_step = (last_epoch + 1) * FLAGS._steps_per_epoch
        if is_root_rank:
            logging.info('Loaded checkpoint {} at epoch {}.'.format(
                FLAGS.resume, last_epoch))
    else:
        lr_scheduler = optim.get_lr_scheduler(optimizer, FLAGS)
        # last_epoch = lr_scheduler.last_epoch
        last_epoch = -1
        best_val = 1.
        train_meters = get_meters('train')
        val_meters = get_meters('val')
        FLAGS._global_step = 0

    if not FLAGS.resume and is_root_rank:
        logging.info(model_wrapper)
    assert FLAGS.profiling, '`m.macs` is used for calculating penalty'
    if FLAGS.profiling:
        if 'gpu' in FLAGS.profiling:
            profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            profiling(model, use_cuda=False)

    # data
    (train_transforms, val_transforms,
     test_transforms) = dataflow.data_transforms(FLAGS)
    (train_set, val_set, test_set) = dataflow.dataset(train_transforms,
                                                      val_transforms,
                                                      test_transforms, FLAGS)
    (train_loader, calib_loader, val_loader,
     test_loader) = dataflow.data_loader(train_set, val_set, test_set, FLAGS)

    # get bn's weights
    FLAGS._bn_to_prune = prune.get_bn_to_prune(model, FLAGS.prune_params)
    rho_scheduler = prune.get_rho_scheduler(FLAGS.prune_params,
                                            FLAGS._steps_per_epoch)

    if FLAGS.test_only and (test_loader is not None):
        if is_root_rank:
            logging.info('Start testing.')
        test_meters = get_meters('test')
        validate(last_epoch, calib_loader, test_loader, criterion,
                 test_meters, model_wrapper, ema, 'test')
        return

    # already broadcast by AllReduceDistributedDataParallel
    # optimizer loads same checkpoint/same initialization
    if is_root_rank:
        logging.info('Start training.')
    for epoch in range(last_epoch + 1, FLAGS.num_epochs):
        # train
        results = run_one_epoch(epoch,
                                train_loader,
                                model_wrapper,
                                criterion_smooth,
                                optimizer,
                                lr_scheduler,
                                ema,
                                rho_scheduler,
                                train_meters,
                                phase='train')

        # val
        results, model_eval_wrapper = validate(epoch, calib_loader,
                                               val_loader, criterion,
                                               val_meters, model_wrapper,
                                               ema, 'val')

        if FLAGS.prune_params['method'] is not None:
            prune_threshold = FLAGS.model_shrink_threshold
            masks = prune.cal_mask_network_slimming_by_threshold(
                get_prune_weights(model_eval_wrapper), prune_threshold)
            FLAGS._bn_to_prune.add_info_list('mask', masks)
            flops_pruned, infos = prune.cal_pruned_flops(FLAGS._bn_to_prune)
            log_pruned_info(unwrap_model(model_eval_wrapper), flops_pruned,
                            infos, prune_threshold)
            if flops_pruned >= FLAGS.model_shrink_delta_flops \
                    or epoch == FLAGS.num_epochs - 1:
                ema_only = (epoch == FLAGS.num_epochs - 1)
                shrink_model(model_wrapper, ema, optimizer,
                             FLAGS._bn_to_prune, prune_threshold, ema_only)
        model_kwparams = mb.output_network(unwrap_model(model_wrapper))

        if results['top1_error'] < best_val:
            best_val = results['top1_error']
            if is_root_rank:
                save_status(model_wrapper, model_kwparams, optimizer, ema,
                            epoch, best_val, (train_meters, val_meters),
                            os.path.join(FLAGS.log_dir, 'best_model'))
                logging.info(
                    'New best validation top1 error: {:.4f}'.format(best_val))

        if is_root_rank:
            # save latest checkpoint
            save_status(model_wrapper, model_kwparams, optimizer, ema, epoch,
                        best_val, (train_meters, val_meters),
                        os.path.join(FLAGS.log_dir, 'latest_checkpoint'))

        # NOTE: per the scheduler docs, step() should be called after
        # train/val; a stepwise scheduler is used instead.
        # lr_scheduler.step()
    return