def redesign_model(org_model, model_config, model_label, model_type='original'):
    logger.info('[{} model]'.format(model_label))
    frozen_module_path_set = set(model_config.get('frozen_modules', list()))
    module_paths = model_config.get('sequential', list())
    if not isinstance(module_paths, list) or len(module_paths) == 0:
        # No 'sequential' list given: keep the original architecture and only freeze modules
        logger.info('Using the {} {} model'.format(model_type, model_label))
        if len(frozen_module_path_set) > 0:
            logger.info('Frozen module(s): {}'.format(frozen_module_path_set))

        isinstance_str = 'instance('
        for frozen_module_path in frozen_module_path_set:
            if frozen_module_path.startswith(isinstance_str) and frozen_module_path.endswith(')'):
                # 'instance(ClassName)' notation freezes every module of the given nn class
                target_cls = nn.__dict__[frozen_module_path[len(isinstance_str):-1]]
                for m in org_model.modules():
                    if isinstance(m, target_cls):
                        freeze_module_params(m)
            else:
                # Otherwise treat the entry as a module path within the model
                module = get_module(org_model, frozen_module_path)
                freeze_module_params(module)
        return org_model

    # Rebuild the model as a sequential container of the listed module paths
    logger.info('Redesigning the {} model with {}'.format(model_label, module_paths))
    if len(frozen_module_path_set) > 0:
        logger.info('Frozen module(s): {}'.format(frozen_module_path_set))

    module_dict = OrderedDict()
    adaptation_dict = model_config.get('adaptations', dict())
    for frozen_module_path in frozen_module_path_set:
        module = get_module(org_model, frozen_module_path)
        freeze_module_params(module)

    for module_path in module_paths:
        if module_path.startswith('+'):
            # '+' prefix marks an adaptation module built from the 'adaptations' config
            module_path = module_path[1:]
            adaptation_config = adaptation_dict[module_path]
            module = get_adaptation_module(adaptation_config['type'], **adaptation_config['params'])
        else:
            module = get_module(org_model, module_path)

        if module_path in frozen_module_path_set:
            freeze_module_params(module)

        add_submodule(module, module_path, module_dict)
    return build_sequential_container(module_dict)
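# Illustrative sketch (not part of the original code): two hypothetical model_config
# dicts showing the two branches of redesign_model above. The keys 'frozen_modules',
# 'sequential' and 'adaptations', the 'instance(...)' notation, and the '+' prefix are
# the conventions read by the function; the module paths, the adaptation type and its
# params below are made-up placeholders.
# Branch 1: no 'sequential' list -> keep the original architecture, only freeze modules.
# 'instance(BatchNorm2d)' freezes every nn.BatchNorm2d; 'layer1' is a module path.
EXAMPLE_FROZEN_ONLY_CONFIG = {
    'frozen_modules': ['instance(BatchNorm2d)', 'layer1']
}
# Branch 2: 'sequential' rebuilds the model as a sequential container; '+bottleneck'
# is instantiated from 'adaptations' instead of being looked up in org_model.
EXAMPLE_REDESIGN_CONFIG = {
    'frozen_modules': ['layer1'],
    'sequential': ['conv1', 'bn1', 'relu', 'layer1', 'layer2', '+bottleneck'],
    'adaptations': {
        'bottleneck': {'type': 'SomeAdaptationModule', 'params': {'in_channels': 512, 'out_channels': 256}}
    }
}
# e.g. student_model = redesign_model(org_model, EXAMPLE_REDESIGN_CONFIG, model_label='student')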
def setup(self, train_config):
    # Set up train and val data loaders
    self.setup_data_loaders(train_config)

    # Define model used in this stage
    model_config = train_config.get('model', dict())
    self.setup_model(model_config)

    # Define loss function used in this stage
    self.setup_loss(train_config)

    # Wrap the model if necessary
    self.model =\
        wrap_model(self.model, model_config, self.device, self.device_ids, self.distributed, self.model_any_frozen)
    if not model_config.get('requires_grad', True):
        logger.info('Freezing the whole model')
        freeze_module_params(self.model)

    # Set up optimizer and scheduler
    optim_config = train_config.get('optimizer', dict())
    optimizer_reset = False
    if len(optim_config) > 0:
        optim_params_config = optim_config['params']
        if 'lr' in optim_params_config:
            optim_params_config['lr'] *= self.lr_factor

        module_wise_params_configs = optim_config.get('module_wise_params', list())
        if len(module_wise_params_configs) > 0:
            trainable_module_list = list()
            for module_wise_params_config in module_wise_params_configs:
                module_wise_params_dict = dict()
                module_wise_params_dict.update(module_wise_params_config['params'])
                if 'lr' in module_wise_params_dict:
                    module_wise_params_dict['lr'] *= self.lr_factor

                module = get_module(self, module_wise_params_config['module'])
                module_wise_params_dict['params'] = module.parameters()
                trainable_module_list.append(module_wise_params_dict)
        else:
            trainable_module_list = nn.ModuleList([self.model])

        self.optimizer = get_optimizer(trainable_module_list, optim_config['type'], optim_params_config)
        optimizer_reset = True

    scheduler_config = train_config.get('scheduler', None)
    if scheduler_config is not None and len(scheduler_config) > 0:
        self.lr_scheduler = get_scheduler(self.optimizer, scheduler_config['type'], scheduler_config['params'])
    elif optimizer_reset:
        self.lr_scheduler = None

    # Set up apex if mixed-precision training is required
    self.apex = False
    apex_config = train_config.get('apex', None)
    if apex_config is not None and apex_config.get('requires', False):
        if sys.version_info < (3, 0):
            raise RuntimeError('Apex currently only supports Python 3. Aborting.')
        if amp is None:
            raise RuntimeError('Failed to import apex. Please install apex from https://www.github.com/nvidia/apex '
                               'to enable mixed-precision training.')
        self.model, self.optimizer =\
            amp.initialize(self.model, self.optimizer, opt_level=apex_config['opt_level'])
        self.apex = True
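# Illustrative sketch (not part of the original code): a hypothetical train_config for the
# single-model setup() above. The key names ('model', 'optimizer', 'scheduler', 'apex',
# 'requires_grad', 'module_wise_params', 'requires', 'opt_level') are the ones setup()
# reads; the optimizer/scheduler types, the module path and the hyperparameter values
# are placeholders.
EXAMPLE_TRAIN_CONFIG = {
    'model': {'requires_grad': True},
    'optimizer': {
        'type': 'SGD',
        'params': {'lr': 0.1, 'momentum': 0.9, 'weight_decay': 5e-4},
        # Optional per-module parameter groups; 'module' is resolved relative to the
        # training box, so 'model.fc' refers to self.model.fc
        'module_wise_params': [
            {'module': 'model.fc', 'params': {'lr': 0.01}}
        ]
    },
    'scheduler': {'type': 'MultiStepLR', 'params': {'milestones': [30, 60], 'gamma': 0.1}},
    # Apex mixed precision is enabled only when 'requires' is True and apex is importable
    'apex': {'requires': False, 'opt_level': 'O1'}
}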
def setup(self, train_config):
    # Set up train and val data loaders
    self.setup_data_loaders(train_config)

    # Define teacher and student models used in this stage
    teacher_config = train_config.get('teacher', dict())
    student_config = train_config.get('student', dict())
    self.setup_teacher_student_models(teacher_config, student_config)

    # Define loss function used in this stage
    self.setup_loss(train_config)

    # Freeze parameters if specified
    self.teacher_updatable = True
    if not teacher_config.get('requires_grad', True):
        logger.info('Freezing the whole teacher model')
        freeze_module_params(self.teacher_model)
        self.teacher_updatable = False

    if not student_config.get('requires_grad', True):
        logger.info('Freezing the whole student model')
        freeze_module_params(self.student_model)

    # Wrap models if necessary
    teacher_unused_parameters = teacher_config.get('find_unused_parameters', self.teacher_any_frozen)
    teacher_any_updatable = len(get_updatable_param_names(self.teacher_model)) > 0
    self.teacher_model =\
        wrap_model(self.teacher_model, teacher_config, self.device, self.device_ids, self.distributed,
                   teacher_unused_parameters, teacher_any_updatable)
    student_unused_parameters = student_config.get('find_unused_parameters', self.student_any_frozen)
    student_any_updatable = len(get_updatable_param_names(self.student_model)) > 0
    self.student_model =\
        wrap_model(self.student_model, student_config, self.device, self.device_ids, self.distributed,
                   student_unused_parameters, student_any_updatable)

    # Set up optimizer and scheduler
    optim_config = train_config.get('optimizer', dict())
    optimizer_reset = False
    if len(optim_config) > 0:
        optim_params_config = optim_config['params']
        if 'lr' in optim_params_config:
            optim_params_config['lr'] *= self.lr_factor

        module_wise_params_configs = optim_config.get('module_wise_params', list())
        if len(module_wise_params_configs) > 0:
            trainable_module_list = list()
            for module_wise_params_config in module_wise_params_configs:
                module_wise_params_dict = dict()
                if isinstance(module_wise_params_config.get('params', None), dict):
                    module_wise_params_dict.update(module_wise_params_config['params'])
                if 'lr' in module_wise_params_dict:
                    module_wise_params_dict['lr'] *= self.lr_factor

                target_model = \
                    self.teacher_model if module_wise_params_config.get('is_teacher', False) else self.student_model
                module = get_module(target_model, module_wise_params_config['module'])
                module_wise_params_dict['params'] = module.parameters()
                trainable_module_list.append(module_wise_params_dict)
        else:
            trainable_module_list = nn.ModuleList([self.student_model])
            if self.teacher_updatable:
                logger.info('Note that you are training some/all of the modules in the teacher model')
                trainable_module_list.append(self.teacher_model)

        filters_params = optim_config.get('filters_params', True)
        self.optimizer = \
            get_optimizer(trainable_module_list, optim_config['type'], optim_params_config, filters_params)
        self.optimizer.zero_grad()
        self.max_grad_norm = optim_config.get('max_grad_norm', None)
        self.grad_accum_step = optim_config.get('grad_accum_step', 1)
        optimizer_reset = True

    scheduler_config = train_config.get('scheduler', None)
    if scheduler_config is not None and len(scheduler_config) > 0:
        self.lr_scheduler = get_scheduler(self.optimizer, scheduler_config['type'], scheduler_config['params'])
        self.scheduling_step = scheduler_config.get('scheduling_step', 0)
    elif optimizer_reset:
        self.lr_scheduler = None
        self.scheduling_step = None

    # Set up accelerator/apex if necessary
    self.apex = False
    apex_config = train_config.get('apex', None)
    if self.accelerator is not None:
        if self.teacher_updatable:
            self.teacher_model, self.student_model, self.optimizer, self.train_data_loader, self.val_data_loader = \
                self.accelerator.prepare(self.teacher_model, self.student_model, self.optimizer,
                                         self.train_data_loader, self.val_data_loader)
        else:
            self.teacher_model = self.teacher_model.to(self.accelerator.device)
            if self.accelerator.state.use_fp16:
                self.teacher_model = self.teacher_model.half()

            self.student_model, self.optimizer, self.train_data_loader, self.val_data_loader = \
                self.accelerator.prepare(self.student_model, self.optimizer,
                                         self.train_data_loader, self.val_data_loader)
    elif apex_config is not None and apex_config.get('requires', False):
        if sys.version_info < (3, 0):
            raise RuntimeError('Apex currently only supports Python 3. Aborting.')
        if amp is None:
            raise RuntimeError('Failed to import apex. Please install apex from https://www.github.com/nvidia/apex '
                               'to enable mixed-precision training.')
        self.student_model, self.optimizer =\
            amp.initialize(self.student_model, self.optimizer, opt_level=apex_config['opt_level'])
        self.apex = True
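# Illustrative sketch (not part of the original code): a hypothetical per-stage train_config
# for the teacher-student setup() above. The key names ('teacher', 'student', 'requires_grad',
# 'find_unused_parameters', 'module_wise_params', 'is_teacher', 'filters_params',
# 'max_grad_norm', 'grad_accum_step', 'scheduling_step', 'apex') are the ones setup() reads;
# the module paths, optimizer/scheduler types and hyperparameter values are placeholders.
EXAMPLE_DISTILLATION_TRAIN_CONFIG = {
    'teacher': {'requires_grad': True},     # set False to freeze the teacher (teacher_updatable = False)
    'student': {'requires_grad': True, 'find_unused_parameters': True},
    'optimizer': {
        'type': 'Adam',
        'params': {'lr': 1e-3},
        'filters_params': True,
        # Per-module parameter groups; 'is_teacher': True resolves 'module' against the
        # teacher model, otherwise against the student model
        'module_wise_params': [
            {'module': 'fc', 'params': {'lr': 1e-2}},
            {'module': 'layer4', 'is_teacher': True, 'params': {'lr': 1e-4}}
        ],
        'max_grad_norm': 1.0,    # stored by setup() as self.max_grad_norm
        'grad_accum_step': 2     # stored by setup() as self.grad_accum_step
    },
    'scheduler': {'type': 'StepLR', 'params': {'step_size': 10, 'gamma': 0.5}, 'scheduling_step': 0},
    # When an accelerator is attached, the accelerator branch above takes precedence over apex
    'apex': {'requires': False, 'opt_level': 'O1'}
}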