def setup_distributed(num_images=None):
    """Setup distributed related parameters."""
    # init distributed
    if FLAGS.use_distributed:
        udist.init_dist()
        FLAGS.batch_size = udist.get_world_size() * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.per_gpu_batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = \
                FLAGS.bn_calibration_per_gpu_batch_size
        # per-GPU workers, rounded to the nearest integer
        FLAGS.data_loader_workers = round(
            FLAGS.data_loader_workers / udist.get_local_size())
    else:
        count = torch.cuda.device_count()
        FLAGS.batch_size = count * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = \
                FLAGS.bn_calibration_per_gpu_batch_size * count
    if hasattr(FLAGS, 'base_lr'):
        FLAGS.lr = FLAGS.base_lr * (FLAGS.batch_size / FLAGS.base_total_batch)
    if num_images:
        # NOTE: never drop the last batch, so use ceil (the smallest integer
        # not less than x); otherwise the learning rate can become negative
        FLAGS._steps_per_epoch = math.ceil(num_images / FLAGS.batch_size)
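
# A minimal worked example of the linear LR-scaling rule implemented in
# setup_distributed() above. All numbers here are illustrative assumptions,
# not values from any config in this repo.
def _example_lr_scaling():
    base_lr, base_total_batch = 0.1, 256
    per_gpu_batch_size, num_gpus = 64, 8
    batch_size = num_gpus * per_gpu_batch_size           # 512
    lr = base_lr * (batch_size / base_total_batch)       # 0.1 * 2 = 0.2
    # with ceil, the last partial batch still counts as one step
    steps_per_epoch = math.ceil(1281167 / batch_size)    # 2503 for ImageNet
    return lr, steps_per_epoch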
def train_val_test():
    """train and val"""
    torch.backends.cudnn.benchmark = True
    # init distributed
    if getattr(FLAGS, 'distributed', False):
        init_dist()
    # seed
    if getattr(FLAGS, 'use_diff_seed', False):
        print('use diff seed is True')
        while not is_initialized():
            print('Waiting for initialization ...')
            time.sleep(5)
        print('Expected seed: {}'.format(
            getattr(FLAGS, 'random_seed', 0) + get_rank()))
        set_random_seed(getattr(FLAGS, 'random_seed', 0) + get_rank())
    else:
        set_random_seed()
    # experiment setting
    experiment_setting = get_experiment_setting()
    # model
    model, model_wrapper = get_model()
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    if getattr(FLAGS, 'profiling_only', False):
        if 'gpu' in FLAGS.profiling:
            profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            profiling(model, use_cuda=False)
        return
    # data
    train_transforms, val_transforms, test_transforms = data_transforms()
    train_set, val_set, test_set = dataset(train_transforms, val_transforms,
                                           test_transforms)
    train_loader, val_loader, test_loader = data_loader(
        train_set, val_set, test_set)
    log_dir = FLAGS.log_dir
    log_dir = os.path.join(log_dir, experiment_setting)
    checkpoint = torch.load(os.path.join(log_dir, 'best_model.pt'),
                            map_location=lambda storage, loc: storage)
    model_wrapper.load_state_dict(checkpoint['model'])
    optimizer = get_optimizer(model_wrapper)
    mprint('Start testing.')
    test_meters = get_meters('test')
    # restore EMA shadow weights if the checkpoint provides them
    ema = checkpoint.get('ema', None)
    with torch.no_grad():
        run_one_epoch(-1, test_loader, model_wrapper, criterion, optimizer,
                      test_meters, phase='test', ema=ema)
def main():
    """Entry."""
    # init distributed
    global is_root_rank
    if FLAGS.use_distributed:
        udist.init_dist()
        FLAGS.batch_size = udist.get_world_size() * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.per_gpu_batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = \
                FLAGS.bn_calibration_per_gpu_batch_size
        FLAGS.data_loader_workers = round(
            FLAGS.data_loader_workers / udist.get_local_size())
        is_root_rank = udist.is_master()
    else:
        count = torch.cuda.device_count()
        FLAGS.batch_size = count * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = \
                FLAGS.bn_calibration_per_gpu_batch_size * count
        is_root_rank = True
    FLAGS.lr = FLAGS.base_lr * (FLAGS.batch_size / FLAGS.base_total_batch)
    # NOTE: never drop the last batch, so use ceil; otherwise the learning
    # rate can become negative
    FLAGS._steps_per_epoch = int(
        np.ceil(NUM_IMAGENET_TRAIN / FLAGS.batch_size))
    if is_root_rank:
        FLAGS.log_dir = '{}/{}'.format(FLAGS.log_dir,
                                       time.strftime("%Y%m%d-%H%M%S"))
        create_exp_dir(
            FLAGS.log_dir,
            FLAGS.config_path,
            blacklist_dirs=[
                'exp', '.git', 'pretrained', 'tmp', 'deprecated', 'bak',
            ],
        )
        setup_logging(FLAGS.log_dir)
        for k, v in _ENV_EXPAND.items():
            logging.info('Env var expand: {} to {}'.format(k, v))
        logging.info(FLAGS)
    set_random_seed(FLAGS.get('random_seed', 0))
    with SummaryWriterManager():
        train_val_test()
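
# Hedged sketch of the seeding helper assumed by main() and the training
# entry points above. The real set_random_seed() lives elsewhere in the
# repo; this only illustrates the usual ingredients for reproducible
# PyTorch runs.
def _set_random_seed_sketch(seed=0):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)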
def get_model():
    """get model"""
    model_lib = importlib.import_module(FLAGS.model)
    model = model_lib.Model(FLAGS.num_classes)
    if getattr(FLAGS, 'distributed', False):
        gpu_id = init_dist()
        if getattr(FLAGS, 'distributed_all_reduce', False):
            model_wrapper = AllReduceDistributedDataParallel(model.cuda())
        else:
            model_wrapper = torch.nn.parallel.DistributedDataParallel(
                model.cuda(), [gpu_id], gpu_id)
    else:
        model_wrapper = torch.nn.DataParallel(model).cuda()
    return model, model_wrapper
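
# The positional DistributedDataParallel arguments above map to the keyword
# form below. A minimal sketch, assuming init_dist() has already set up the
# process group and returned this process's GPU id.
def _wrap_ddp_sketch(model, gpu_id):
    return torch.nn.parallel.DistributedDataParallel(
        model.cuda(), device_ids=[gpu_id], output_device=gpu_id)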
def get_model():
    """get model"""
    model_lib = importlib.import_module(FLAGS.model)
    model = model_lib.Model(FLAGS.num_classes, input_size=FLAGS.image_size)
    # sanity-check forward pass on a dummy 32x32 batch
    # (torch.autograd.Variable is deprecated; plain tensors suffice)
    inputs = torch.randn((2, 3, 32, 32))
    output = model(inputs)
    print(output)
    print(output.size())
    if getattr(FLAGS, 'distributed', False):
        gpu_id = init_dist()
        if getattr(FLAGS, 'distributed_all_reduce', False):
            # seems faster
            model_wrapper = AllReduceDistributedDataParallel(model.cuda())
        else:
            model_wrapper = torch.nn.parallel.DistributedDataParallel(
                model.cuda(), [gpu_id], gpu_id)
    else:
        model_wrapper = torch.nn.DataParallel(model).cuda()
    return model, model_wrapper
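
# Hedged sketch of the contract importlib.import_module(FLAGS.model) relies
# on: the module named by FLAGS.model must expose a Model class with this
# constructor signature. In a real run this class would live in its own
# module; the body here is a minimal illustrative stand-in, not a model
# from this repo.
class Model(torch.nn.Module):
    def __init__(self, num_classes, input_size=32):
        super().__init__()
        self.pool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Linear(3, num_classes)

    def forward(self, x):
        # global-average-pool the input, then classify
        return self.fc(self.pool(x).flatten(1))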
def train_val_test():
    """train and val"""
    torch.backends.cudnn.benchmark = True
    # init distributed
    if getattr(FLAGS, 'distributed', False):
        init_dist()
    # seed
    if getattr(FLAGS, 'use_diff_seed', False) and not getattr(
            FLAGS, 'stoch_valid', False):
        print('use diff seed is True')
        while not is_initialized():
            print('Waiting for initialization ...')
            time.sleep(5)
        print('Expected seed: {}'.format(
            getattr(FLAGS, 'random_seed', 0) + get_rank()))
        set_random_seed(getattr(FLAGS, 'random_seed', 0) + get_rank())
    else:
        set_random_seed()
    # experiment setting
    experiment_setting = get_experiment_setting()
    # model
    model, model_wrapper = get_model()
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    if getattr(FLAGS, 'profiling_only', False):
        if 'gpu' in FLAGS.profiling:
            profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            profiling(model, use_cuda=False)
        return
    # data
    train_transforms, val_transforms, test_transforms = data_transforms()
    train_set, val_set, test_set = dataset(
        train_transforms, val_transforms, test_transforms)
    train_loader, val_loader, test_loader = data_loader(
        train_set, val_set, test_set)
    log_dir = FLAGS.log_dir
    log_dir = os.path.join(log_dir, experiment_setting)
    # full precision pretrained
    if getattr(FLAGS, 'fp_pretrained_file', None):
        checkpoint = torch.load(
            FLAGS.fp_pretrained_file,
            map_location=lambda storage, loc: storage)
        # update keys from external models
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if getattr(FLAGS, 'pretrained_model_remap_keys', False):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                mprint('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_dict = model_wrapper.state_dict()
        # switch bn: populate one BatchNorm entry per bit-width switch
        for k in list(checkpoint.keys()):
            if 'bn' in k:
                for bn_idx in range(len(FLAGS.bits_list)):
                    k_new = (k.split('bn')[0] + 'bn' + k.split('bn')[1][0]
                             + str(bn_idx) + k.split('bn')[1][2:])
                    mprint(k)
                    mprint(k_new)
                    checkpoint[k_new] = model_dict[k]
        if getattr(FLAGS, 'switch_alpha', False):
            for k, v in checkpoint.items():
                if 'alpha' in k and checkpoint[k].size() != model_dict[k].size():
                    # replicate the single alpha across all bit-width switches
                    checkpoint[k] = nn.Parameter(torch.stack(
                        [checkpoint[k]
                         for _ in range(model_dict[k].size()[0])]))
        # remove unexpected keys
        for k in list(checkpoint.keys()):
            if k not in model_dict:
                checkpoint.pop(k)
        model_dict.update(checkpoint)
        model_wrapper.load_state_dict(model_dict)
        mprint('Loaded full precision model {}.'.format(
            FLAGS.fp_pretrained_file))
    # check pretrained
    if FLAGS.pretrained_file:
        pretrained_dir = FLAGS.pretrained_dir
        pretrained_dir = os.path.join(pretrained_dir, experiment_setting)
        pretrained_file = os.path.join(pretrained_dir, FLAGS.pretrained_file)
        checkpoint = torch.load(
            pretrained_file, map_location=lambda storage, loc: storage)
        # update keys from external models
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if getattr(FLAGS, 'pretrained_model_remap_keys', False):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                mprint('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_wrapper.load_state_dict(checkpoint)
        mprint('Loaded model {}.'.format(pretrained_file))
    optimizer = get_optimizer(model_wrapper)
    if FLAGS.test_only and (test_loader is not None):
        mprint('Start profiling.')
        if 'gpu' in FLAGS.profiling:
            profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            profiling(model, use_cuda=False)
        mprint('Start testing.')
        test_meters = get_meters('test')
        with torch.no_grad():
            run_one_epoch(
                -1, test_loader, model_wrapper, criterion, optimizer,
                test_meters, phase='test')
        return
    # check resume training
    if os.path.exists(os.path.join(log_dir, 'latest_checkpoint.pt')):
        checkpoint = torch.load(
            os.path.join(log_dir, 'latest_checkpoint.pt'),
            map_location=lambda storage, loc: storage)
        model_wrapper.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        last_epoch = checkpoint['last_epoch']
        if FLAGS.lr_scheduler in ['exp_decaying_iter', 'cos_annealing_iter',
                                  'multistep_iter']:
            lr_scheduler = get_lr_scheduler(optimizer, len(train_loader))
            lr_scheduler.last_epoch = last_epoch * len(train_loader)
        else:
            lr_scheduler = get_lr_scheduler(optimizer)
            lr_scheduler.last_epoch = last_epoch
        best_val = checkpoint['best_val']
        train_meters, val_meters = checkpoint['meters']
        mprint('Loaded checkpoint {} at epoch {}.'.format(
            log_dir, last_epoch))
    else:
        if FLAGS.lr_scheduler in ['exp_decaying_iter', 'cos_annealing_iter',
                                  'multistep_iter']:
            lr_scheduler = get_lr_scheduler(optimizer, len(train_loader))
        else:
            lr_scheduler = get_lr_scheduler(optimizer)
        last_epoch = lr_scheduler.last_epoch
        best_val = 1.
        train_meters = get_meters('train')
        val_meters = get_meters('val')
        # if starting from scratch, print model and do profiling
        mprint(model_wrapper)
        if getattr(FLAGS, 'profiling', False):
            if 'gpu' in FLAGS.profiling:
                profiling(model, use_cuda=True)
            if 'cpu' in FLAGS.profiling:
                profiling(model, use_cuda=False)
    if getattr(FLAGS, 'log_dir', None):
        try:
            os.makedirs(log_dir)
        except OSError:
            pass
    mprint('Start training.')
    for epoch in range(last_epoch + 1, FLAGS.num_epochs):
        if FLAGS.lr_scheduler in ['exp_decaying_iter', 'cos_annealing_iter',
                                  'multistep_iter']:
            lr_sched = lr_scheduler
        else:
            lr_sched = None
        # For PyTorch 1.1+, comment the following line
        # lr_scheduler.step()
        # train
        mprint(' train '.center(40, '*'))
        run_one_epoch(
            epoch, train_loader, model_wrapper, criterion, optimizer,
            train_meters, phase='train', scheduler=lr_sched)
        # val
        mprint(' validation '.center(40, '~'))
        if val_meters is not None:
            val_meters['best_val'].cache(best_val)
        with torch.no_grad():
            top1_error = run_one_epoch(
                epoch, val_loader, model_wrapper, criterion, optimizer,
                val_meters, phase='val')
        if is_master():
            if top1_error < best_val:
                best_val = top1_error
                torch.save(
                    {'model': model_wrapper.state_dict()},
                    os.path.join(log_dir, 'best_model.pt'))
                mprint('New best validation top1 error: {:.3f}'.format(
                    best_val))
            # save latest checkpoint
            torch.save(
                {
                    'model': model_wrapper.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'last_epoch': epoch,
                    'best_val': best_val,
                    'meters': (train_meters, val_meters),
                },
                os.path.join(log_dir, 'latest_checkpoint.pt'))
        # For PyTorch 1.0 or earlier, comment out the following two lines
        if FLAGS.lr_scheduler not in ['exp_decaying_iter',
                                      'cos_annealing_iter',
                                      'multistep_iter']:
            lr_scheduler.step()
    if is_master():
        profiling(model, use_cuda=True)
    return
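
# Hedged sketch of what the '*_iter' scheduler branches above assume: when
# the scheduler name ends in '_iter', get_lr_scheduler() receives
# len(train_loader) so the schedule can step once per batch instead of once
# per epoch. A per-iteration cosine schedule could be built like this
# (illustrative only, not the repo's actual get_lr_scheduler):
def _cosine_iter_scheduler_sketch(optimizer, num_epochs, steps_per_epoch):
    return torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs * steps_per_epoch)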
def train_val_test():
    """train and val"""
    torch.backends.cudnn.benchmark = True
    # init distributed
    if getattr(FLAGS, 'distributed', False):
        init_dist()
    # seed
    if getattr(FLAGS, 'use_diff_seed', False) and not getattr(
            FLAGS, 'stoch_valid', False):
        print('use diff seed is True')
        while not is_initialized():
            print('Waiting for initialization ...')
            time.sleep(5)
        print('Expected seed: {}'.format(
            getattr(FLAGS, 'random_seed', 0) + get_rank()))
        set_random_seed(getattr(FLAGS, 'random_seed', 0) + get_rank())
    else:
        set_random_seed()
    # experiment setting
    experiment_setting = get_experiment_setting()
    # model
    model, model_wrapper = get_model()
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    if getattr(FLAGS, 'profiling_only', False):
        if 'gpu' in FLAGS.profiling:
            profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            profiling(model, use_cuda=False)
        return
    # exponential moving average of model weights
    ema_decay = getattr(FLAGS, 'ema_decay', None)
    if ema_decay:
        ema = EMA(ema_decay)
        ema.shadow_register(model_wrapper)
    else:
        ema = None
    # data
    train_transforms, val_transforms, test_transforms = data_transforms()
    train_set, val_set, test_set = dataset(
        train_transforms, val_transforms, test_transforms)
    train_loader, val_loader, test_loader = data_loader(
        train_set, val_set, test_set)
    log_dir = FLAGS.log_dir
    log_dir = os.path.join(log_dir, experiment_setting)
    io = UltronIO('hdfs://haruna/home')
    # full precision pretrained
    if getattr(FLAGS, 'fp_pretrained_file', None):
        checkpoint = io.torch_load(
            FLAGS.fp_pretrained_file,
            map_location=lambda storage, loc: storage)
        # update keys from external models
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if getattr(FLAGS, 'pretrained_model_remap_keys', False):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                mprint('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_dict = model_wrapper.state_dict()
        # remove unexpected keys
        for k in list(checkpoint.keys()):
            if k not in model_dict:
                checkpoint.pop(k)
        model_dict.update(checkpoint)
        model_wrapper.load_state_dict(model_dict)
        mprint('Loaded full precision model {}.'.format(
            FLAGS.fp_pretrained_file))
    # check pretrained
    if FLAGS.pretrained_file and FLAGS.pretrained_dir:
        pretrained_dir = FLAGS.pretrained_dir
        pretrained_file = os.path.join(pretrained_dir, FLAGS.pretrained_file)
        checkpoint = io.torch_load(
            pretrained_file, map_location=lambda storage, loc: storage)
        if getattr(FLAGS, 'pretrained_model_remap_keys', False):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                mprint('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        # filter lamda_w and lamda_a args: broadcast each stored scalar to
        # the shape the current model expects
        for k, v in checkpoint['model'].items():
            if 'lamda_w' in k or 'lamda_a' in k:
                checkpoint['model'][k] = v.repeat(
                    model_wrapper.state_dict()[k].size())
        model_wrapper.load_state_dict(checkpoint['model'])
        mprint('Loaded model {}.'.format(pretrained_file))
    optimizer = get_optimizer(model_wrapper)
    if FLAGS.test_only and (test_loader is not None):
        mprint('Start testing.')
        ema = checkpoint.get('ema', None)
        test_meters = get_meters('test')
        with torch.no_grad():
            run_one_epoch(
                -1, test_loader, model_wrapper, criterion, optimizer,
                test_meters, phase='test', ema=ema)
        return
    # check resume training
    if io.check_path(os.path.join(log_dir, 'latest_checkpoint.pt')):
        checkpoint = io.torch_load(
            os.path.join(log_dir, 'latest_checkpoint.pt'),
            map_location=lambda storage, loc: storage)
        model_wrapper.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        last_epoch = checkpoint['last_epoch']
        if FLAGS.lr_scheduler in ['exp_decaying_iter', 'gaussian_iter',
                                  'cos_annealing_iter', 'butterworth_iter',
                                  'mixed_iter']:
            lr_scheduler = get_lr_scheduler(optimizer, len(train_loader))
            lr_scheduler.last_epoch = last_epoch * len(train_loader)
        else:
            lr_scheduler = get_lr_scheduler(optimizer)
            lr_scheduler.last_epoch = last_epoch
        best_val = checkpoint['best_val']
        train_meters, val_meters = checkpoint['meters']
        ema = checkpoint.get('ema', None)
        mprint('Loaded checkpoint {} at epoch {}.'.format(
            log_dir, last_epoch))
    else:
        if FLAGS.lr_scheduler in ['exp_decaying_iter', 'gaussian_iter',
                                  'cos_annealing_iter', 'butterworth_iter',
                                  'mixed_iter']:
            lr_scheduler = get_lr_scheduler(optimizer, len(train_loader))
        else:
            lr_scheduler = get_lr_scheduler(optimizer)
        last_epoch = lr_scheduler.last_epoch
        best_val = 1.
        train_meters = get_meters('train')
        val_meters = get_meters('val')
        # if starting from scratch, print model and do profiling
        mprint(model_wrapper)
        if getattr(FLAGS, 'profiling', False):
            if 'gpu' in FLAGS.profiling:
                profiling(model, use_cuda=True)
            if 'cpu' in FLAGS.profiling:
                profiling(model, use_cuda=False)
    if getattr(FLAGS, 'log_dir', None):
        try:
            io.create_folder(log_dir)
        except OSError:
            pass
    mprint('Start training.')
    for epoch in range(last_epoch + 1, FLAGS.num_epochs):
        if FLAGS.lr_scheduler in ['exp_decaying_iter', 'gaussian_iter',
                                  'cos_annealing_iter', 'butterworth_iter',
                                  'mixed_iter']:
            lr_sched = lr_scheduler
        else:
            lr_sched = None
        # For PyTorch 1.1+, comment the following line
        # lr_scheduler.step()
        # train
        mprint(' train '.center(40, '*'))
        run_one_epoch(
            epoch, train_loader, model_wrapper, criterion, optimizer,
            train_meters, phase='train', ema=ema, scheduler=lr_sched)
        # val
        mprint(' validation '.center(40, '~'))
        if val_meters is not None:
            val_meters['best_val'].cache(best_val)
        with torch.no_grad():
            if epoch == getattr(FLAGS, 'hard_assign_epoch', float('inf')):
                mprint('Start to use hard assignment')
                setattr(FLAGS, 'hard_assignment', True)
                # bisection search over a global bit-width offset so the
                # hard-assigned model lands within 1% of the target cost
                lower_offset = -1
                higher_offset = 0
                setattr(FLAGS, 'hard_offset', 0)
                with_ratio = 0.01
                bitops, bytesize = profiling(model, use_cuda=True)
                search_trials = 10
                trial = 0
                if getattr(FLAGS, 'weight_only', False):
                    target_bytesize = getattr(FLAGS, 'target_size', 0)
                    while trial < search_trials:
                        trial += 1
                        if bytesize - target_bytesize > \
                                with_ratio * target_bytesize:
                            higher_offset = FLAGS.hard_offset
                        elif bytesize - target_bytesize < \
                                -with_ratio * target_bytesize:
                            lower_offset = FLAGS.hard_offset
                        else:
                            break
                        FLAGS.hard_offset = (higher_offset + lower_offset) / 2
                        bitops, bytesize = profiling(model, use_cuda=True)
                else:
                    target_bitops = getattr(FLAGS, 'target_bitops', 0)
                    while trial < search_trials:
                        trial += 1
                        if bitops - target_bitops > \
                                with_ratio * target_bitops:
                            higher_offset = FLAGS.hard_offset
                        elif bitops - target_bitops < \
                                -with_ratio * target_bitops:
                            lower_offset = FLAGS.hard_offset
                        else:
                            break
                        FLAGS.hard_offset = (higher_offset + lower_offset) / 2
                        bitops, bytesize = profiling(model, use_cuda=True)
                bit_discretizing(model_wrapper)
                setattr(FLAGS, 'hard_offset', 0)
            top1_error = run_one_epoch(
                epoch, val_loader, model_wrapper, criterion, optimizer,
                val_meters, phase='val', ema=ema)
        if is_master():
            if top1_error < best_val:
                best_val = top1_error
                io.torch_save(
                    os.path.join(log_dir, 'best_model.pt'),
                    {'model': model_wrapper.state_dict()})
                mprint('New best validation top1 error: {:.3f}'.format(
                    best_val))
            # save latest checkpoint
            io.torch_save(
                os.path.join(log_dir, 'latest_checkpoint.pt'),
                {
                    'model': model_wrapper.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'last_epoch': epoch,
                    'best_val': best_val,
                    'meters': (train_meters, val_meters),
                    'ema': ema,
                })
        # For PyTorch 1.0 or earlier, comment out the following two lines
        if FLAGS.lr_scheduler not in ['exp_decaying_iter', 'gaussian_iter',
                                      'cos_annealing_iter',
                                      'butterworth_iter', 'mixed_iter']:
            lr_scheduler.step()
    if is_master():
        profiling(model, use_cuda=True)
        for m in model.modules():
            if hasattr(m, 'alpha'):
                mprint(m, m.alpha)
            if hasattr(m, 'lamda_w'):
                mprint(m, m.lamda_w)
            if hasattr(m, 'lamda_a'):
                mprint(m, m.lamda_a)
    return
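
# Hedged sketch of the bisection above, pulled out as a standalone helper:
# it searches a scalar offset until a measured cost lands within a relative
# tolerance of a target. `measure` stands in for re-profiling the model
# after FLAGS.hard_offset changes; all names here are illustrative
# assumptions, not the repo's API.
def _bisect_offset_sketch(measure, target, lo=-1.0, hi=0.0,
                          tol=0.01, max_trials=10):
    offset = 0.0
    cost = measure(offset)
    for _ in range(max_trials):
        if cost - target > tol * target:
            hi = offset        # cost too high: shrink toward lower bits
        elif cost - target < -tol * target:
            lo = offset        # cost too low: allow more bits
        else:
            break              # within tolerance
        offset = (hi + lo) / 2
        cost = measure(offset)
    return offset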
def train_val_test():
    """train and val"""
    torch.backends.cudnn.benchmark = True
    # init distributed
    if getattr(FLAGS, 'distributed', False):
        init_dist()
    # seed
    if getattr(FLAGS, 'use_diff_seed', False) and not getattr(
            FLAGS, 'stoch_valid', False):
        print('use diff seed is True')
        while not is_initialized():
            print('Waiting for initialization ...')
            time.sleep(5)
        print('Expected seed: {}'.format(
            getattr(FLAGS, 'random_seed', 0) + get_rank()))
        set_random_seed(getattr(FLAGS, 'random_seed', 0) + get_rank())
    else:
        set_random_seed()
    if getattr(FLAGS, 'adjust_lr', False):
        # empirical LR correction factor per bit-width
        eta_dict = {
            32: 1.0, 16: 1.0, 8: 1.0, 7: 0.99, 6: 0.98, 5: 0.97,
            4: 0.94, 3: 0.88, 2: 0.77, 1: 0.58,
        }
        eta = lambda b: eta_dict[b]  # noqa: E731
    else:
        eta = None
    # experiment setting
    experiment_setting = get_experiment_setting()
    mprint('stoch_valid: {}, bn_calib_stoch_valid: {}'.format(
        getattr(FLAGS, 'stoch_valid', False),
        getattr(FLAGS, 'bn_calib_stoch_valid', False)))
    # model
    model, model_wrapper = get_model()
    criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()
    if getattr(FLAGS, 'profiling_only', False):
        if 'gpu' in FLAGS.profiling:
            profiling(model, use_cuda=True)
        if 'cpu' in FLAGS.profiling:
            profiling(model, use_cuda=False)
        return
    # exponential moving average of model weights
    ema_decay = getattr(FLAGS, 'ema_decay', None)
    if ema_decay:
        ema = EMA(ema_decay)
        ema.shadow_register(model_wrapper)
    else:
        ema = None
    # data
    train_transforms, val_transforms, test_transforms = data_transforms()
    train_set, val_set, test_set = dataset(train_transforms, val_transforms,
                                           test_transforms)
    train_loader, val_loader, test_loader = data_loader(
        train_set, val_set, test_set)
    log_dir = FLAGS.log_dir
    log_dir = os.path.join(log_dir, experiment_setting)
    # check pretrained
    if FLAGS.pretrained_file:
        pretrained_dir = FLAGS.pretrained_dir
        pretrained_dir = os.path.join(pretrained_dir, experiment_setting)
        pretrained_file = os.path.join(pretrained_dir, FLAGS.pretrained_file)
        checkpoint = torch.load(pretrained_file,
                                map_location=lambda storage, loc: storage)
        # update keys from external models
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            checkpoint = checkpoint['model']
        if getattr(FLAGS, 'pretrained_model_remap_keys', False):
            new_checkpoint = {}
            new_keys = list(model_wrapper.state_dict().keys())
            old_keys = list(checkpoint.keys())
            for key_new, key_old in zip(new_keys, old_keys):
                new_checkpoint[key_new] = checkpoint[key_old]
                mprint('remap {} to {}'.format(key_new, key_old))
            checkpoint = new_checkpoint
        model_wrapper.load_state_dict(checkpoint)
        mprint('Loaded model {}.'.format(pretrained_file))
    optimizer = get_optimizer(model_wrapper)
    # BN calibration on the train set, then validation
    cal_meters = get_meters('cal', single_sample=True)
    mprint('Start calibration.')
    run_one_epoch(-1, train_loader, model_wrapper, criterion, optimizer,
                  cal_meters, phase='cal', ema=ema, single_sample=True)
    mprint('Start validation after calibration.')
    with torch.no_grad():
        run_one_epoch(-1, val_loader, model_wrapper, criterion, optimizer,
                      cal_meters, phase='val', ema=ema, single_sample=True)
    return
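
# Hedged sketch of the EMA helper assumed by the functions above. The real
# EMA class lives elsewhere in the repo; this only illustrates the contract
# the call sites rely on: shadow_register() snapshots the model state, and a
# blending update of the form shadow = decay * shadow + (1 - decay) * value.
# Any method beyond shadow_register() is an assumption here.
class _EMASketch(object):
    def __init__(self, decay):
        self.decay = decay
        self.shadow = {}

    def shadow_register(self, model):
        # keep a detached copy of every floating-point parameter/buffer
        for name, value in model.state_dict().items():
            if value.dtype.is_floating_point:
                self.shadow[name] = value.detach().clone()

    def shadow_update(self, model):
        # shadow = decay * shadow + (1 - decay) * value
        for name, value in model.state_dict().items():
            if name in self.shadow:
                self.shadow[name].mul_(self.decay).add_(
                    value.detach(), alpha=1 - self.decay)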
def init(config):
    """Seed all RNGs and initialize distributed training."""
    random_seed = config.random_seed
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    random.seed(random_seed)
    dist.init_dist(config.distributed.enable, port=config.port)
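
# Hedged usage sketch for init(): a minimal config object carrying just the
# fields the function reads. Attribute names mirror the code above; the
# concrete values are illustrative assumptions.
from types import SimpleNamespace

_example_config = SimpleNamespace(
    random_seed=0,
    port=29500,
    distributed=SimpleNamespace(enable=True),
)
# init(_example_config)  # seeds numpy/torch/random, then sets up dist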