def __init__(self, config, seed=42):
    global logger
    logger = shared_globals.logger
    config = AttrDefault(lambda: None, config)
    self.config = config
    self.datasets = {}
    self.data_loaders = {}
    self.use_swa = config.use_swa
    # self.run.info['epoch'] = 0

    # set random seed
    torch.manual_seed(seed)
    np.random.seed(seed + 1)
    random.seed(seed + 2)

    self.min_lr = self.config.optim_config["min_lr"]
    if self.min_lr is None:
        self.min_lr = 0.0
    print(self.min_lr)

    # making output dirs
    models_outputdir = os.path.join(config.out_dir, "models")
    if not os.path.exists(config.out_dir):
        os.makedirs(config.out_dir)
    if not os.path.exists(models_outputdir):
        os.makedirs(models_outputdir)
    # self.run.info['out_path'] = config.out_dir

    # init_loggers
    self.init_loggers()

    self.dataset_manager = DatasetsManager(self.config['audiodataset'])

    # init TensorBoard
    if self.config.tensorboard:
        tensorboard_write_path = config.tensorboard_write_path
        if not tensorboard_write_path:
            tensorboard_write_path = self.config.out_dir.replace("out", "runs", 1)
        shared_globals.console.info("tensorboard run path: " + tensorboard_write_path)
        shared_globals.console.info("To monitor this experiment use:\n " +
                                    shared_globals.bcolors.FAIL +
                                    "tensorboard --logdir " + tensorboard_write_path +
                                    shared_globals.bcolors.ENDC)
        # self.run.info['tensorboard_path'] = tensorboard_write_path
        self.writer = SummaryWriter(tensorboard_write_path)

    # init model (multi-GPU wrapping happens below)
    self.bare_model = load_model(config.model_config)

    if self.use_swa:
        self.swa_model = load_model(config.model_config)
        if self.config.use_gpu:
            self.swa_model.cuda()
        self.swa_n = 0
        self.swa_c_epochs = config.swa_c_epochs
        self.swa_start = config.swa_start

    if self.config.use_gpu:
        self.bare_model.cuda()

    shared_globals.console.info(
        "Trainable model parameters {}, non-trainable {} ".format(
            count_parameters(self.bare_model),
            count_parameters(self.bare_model, False)))

    # DataParallel mode
    if not config.parallel_mode:
        self.model = self.bare_model
    elif config.parallel_mode == "distributed":
        torch.distributed.init_process_group(
            backend='nccl', world_size=1, rank=0,
            init_method='file://' + config.out_dir + "/shared_file")
        self.model = torch.nn.parallel.DistributedDataParallel(self.bare_model)
    else:
        self.model = torch.nn.DataParallel(self.bare_model)
        # self.model.cuda()

    # if load_model
    if config.get('load_model'):
        load_model_path = config.get('load_model')
        load_model_path = os.path.expanduser(load_model_path)
        shared_globals.console.info("Loading model located at: " + load_model_path)
        checkpoint = torch.load(load_model_path)
        self.model.load_state_dict(checkpoint['state_dict'])
        if self.use_swa:
            swa_state_dict = checkpoint.get('swa_state_dict', None)
            self.swa_n = checkpoint.get('swa_n', 0)
            if (swa_state_dict is not None) and not self.config.swa_model_load_same:
                self.swa_model.load_state_dict(swa_state_dict)
            else:
                shared_globals.console.warning(
                    "No swa_state_dict loaded! same loaded")
                self.swa_model.load_state_dict(checkpoint['state_dict'])
                self.swa_n = 0

    shared_globals.logger.info(str(self.model))
    shared_globals.current_learning_rate = config.optim_config['base_lr']
    self.optimizer, self.scheduler = create_optimizer(
        self.model.parameters(), config.optim_config)
    print("optimizer:", self.optimizer)

    loss_criterion_args = dict(config.loss_criterion_args)
    self.criterion = get_criterion(config.loss_criterion)(**loss_criterion_args)

    # init state
    inf_value = -float("inf")
    if self.config["optim_config"].get("model_selection", {}).get("select_min", False):
        inf_value = float("inf")
    self.state = {
        # 'config': self.config,
        'state_dict': None,
        'optimizer': None,
        'epoch': 0,
        'metrics': {},
        'best_metric_value': inf_value,
        'best_epoch': 0,
    }
    self.first_batch_done = False

    # init dataset loaders
    self.init_loaders()

    if config.get('load_model'):
        if not config.get("load_model_no_test_first"):
            testing_result = {}
            for name in self.config.datasets:
                dataset_config = AttrDefault(lambda: None, self.config.datasets[name])
                if dataset_config.testing:
                    testing_result[name] = self.test(0, name, dataset_config)
            # updating the state with the new results
            self.update_state(testing_result, 0)
def update_state(self, testing_result, epoch):
    state = self.state
    state['epoch'] = epoch
    state['metrics'] = testing_result
    state['state_dict'] = self.bare_model.state_dict()

    model_path = os.path.join(self.config.out_dir, "models",
                              'last_model_{}.pth'.format(epoch))
    if epoch > 250 and epoch % 5 == 0:
        print("saving at ", model_path)
        torch.save(state, model_path)

    selection_config = self.config["optim_config"].get(
        "model_selection", {
            "metric": "accuracy",
            "validation_set": "val",
            "patience": 30
        })

    # update best metric value
    is_it_the_new_best_model = testing_result[selection_config['validation_set']][
        selection_config['metric']] > state['best_metric_value']
    if selection_config.get("select_min", False):
        is_it_the_new_best_model = testing_result[selection_config['validation_set']][
            selection_config['metric']] < state['best_metric_value']

    if is_it_the_new_best_model:
        state['state_dict'] = self.bare_model.state_dict()
        state['optimizer'] = self.optimizer.state_dict()
        state['best_metric_value'] = testing_result[
            selection_config['validation_set']][selection_config['metric']]
        state['best_epoch'] = epoch
        shared_globals.console.info(
            "Epoch {}, found a new best model on set '{}', with {} {}".format(
                epoch, selection_config['validation_set'],
                state['best_metric_value'], selection_config['metric']))
        state['best_metrics'] = testing_result
        state['patience_rest_epoch'] = epoch
        # self.run.info['best_metrics'] = testing_result
        # self.run.info['best_epoch'] = epoch
        model_path = os.path.join(self.config.out_dir, "models",
                                  'model_{}.pth'.format(epoch))
        best_model_path = os.path.join(self.config.out_dir, "models",
                                       'model_best_state.pth')
        torch.save(state, model_path)
        torch.save(state, best_model_path)
        # self.run.info['best_model_path'] = best_model_path
        # self.run.info['best_metric_value'] = state['best_metric_value']
        # self.run.info['best_metric_name'] = selection_config['validation_set'] + "." + selection_config['metric']
    else:
        # logger.info(
        #     "Model didn't improve {} for {} on validation set '{}', patience {} of {} (Best so far {} at epoch {} )".format(
        #         selection_config['metric'], global_run_unique_identifier,
        #         selection_config['validation_set'], str(global_patience_counter),
        #         str(selection_config['patience']), str(state['best_metric_value']), str(state['best_epoch'])))
        patience = selection_config['patience'] - epoch + state['patience_rest_epoch']
        if patience <= 0:
            lr_min_limit = self.config["optim_config"].get(
                "model_selection", {}).get("lr_min_limit", None)
            if (lr_min_limit is None) or shared_globals.current_learning_rate > lr_min_limit:
                shared_globals.current_learning_rate *= self.config[
                    "optim_config"].get("model_selection", {}).get("lr_decay_factor", 1.)
                if selection_config.get("load_optimizer_state"):
                    raise NotImplementedError()
                else:
                    if self.use_swa:
                        shared_globals.console.warning(
                            "SWA doesn't support LR decay via patience")
                    optim_config = self.config['optim_config']
                    optim_config['base_lr'] = shared_globals.current_learning_rate
                    self.optimizer, self.scheduler = create_optimizer(
                        self.model.parameters(), self.config.optim_config)
            else:
                self.config["optim_config"]['model_selection'][
                    'no_best_model_reload'] = True
            best_model_path = os.path.join(self.config.out_dir, "models",
                                           'model_best_state.pth')
            best_epoch_to_reload = "no_reload"
            if not self.config["optim_config"].get(
                    "model_selection", {}).get("no_best_model_reload", False):
                checkpoint = torch.load(best_model_path)
                self.bare_model.load_state_dict(checkpoint['state_dict'])
                best_epoch_to_reload = state['best_epoch']
            state['patience_rest_epoch'] = epoch
            shared_globals.console.info(
                "Patience out({}), Loaded from epoch {}, lr= {} ".format(
                    epoch, best_epoch_to_reload,
                    shared_globals.current_learning_rate))
def __init__(self, config, seed=42, mixed_precision_training=False):
    global logger
    logger = shared_globals.logger
    config = AttrDefault(lambda: None, config)
    self.config = config
    self.datasets = {}
    self.data_loaders = {}
    self.use_swa = config.use_swa
    self.prune_mode = config.get("prune_mode")
    # self.run.info['epoch'] = 0

    # set random seed
    torch.manual_seed(seed)
    np.random.seed(seed + 1)
    random.seed(seed + 2)

    self.min_lr = self.config.optim_config["min_lr"]
    if self.min_lr is None:
        self.min_lr = 0.0
    print(self.min_lr)

    # making output dirs
    models_outputdir = os.path.join(config.out_dir, "models")
    if not os.path.exists(config.out_dir):
        os.makedirs(config.out_dir)
    if not os.path.exists(models_outputdir):
        os.makedirs(models_outputdir)
    # self.run.info['out_path'] = config.out_dir

    self.colab_mode = False
    self.mixed_precision_training = mixed_precision_training
    if mixed_precision_training:
        print("\n\nUsing mixed_precision_training\n\n ")
        self.scaler = torch.cuda.amp.GradScaler()

    # init_loggers
    self.init_loggers()

    self.dataset_manager = DatasetsManager(self.config['audiodataset'])

    # init TensorBoard
    if self.config.tensorboard:
        tensorboard_write_path = config.tensorboard_write_path
        if not tensorboard_write_path:
            tensorboard_write_path = self.config.out_dir.replace("out", "runs", 1)
        shared_globals.console.info("tensorboard run path: " + tensorboard_write_path)
        shared_globals.console.info("To monitor this experiment use:\n " +
                                    shared_globals.bcolors.FAIL +
                                    "tensorboard --logdir " + tensorboard_write_path +
                                    shared_globals.bcolors.ENDC)
        # self.run.info['tensorboard_path'] = tensorboard_write_path
        self.writer = SummaryWriter(tensorboard_write_path)

    # init model (multi-GPU wrapping happens below)
    self.bare_model = load_model(config.model_config)
    print(self.bare_model)

    if self.use_swa:
        self.swa_model = load_model(config.model_config)
        if self.config.use_gpu:
            self.swa_model.cuda()
        self.swa_n = 0
        self.swa_c_epochs = config.swa_c_epochs
        self.swa_start = config.swa_start

    # print number of parameters
    print("Trainable model parameters {}, non-trainable {} ".format(
        count_parameters(self.bare_model),
        count_parameters(self.bare_model, False)))
    print("Trainable model parameters non-zero {} ".format(
        count_non_zero_params(self.bare_model)))

    # move to gpu
    if self.config.use_gpu:
        self.bare_model.cuda()

    if self.prune_mode:
        try:
            true_params = self.bare_model.get_num_true_params()
            prunable_params = self.bare_model.get_num_prunable_params()
            shared_globals.console.info(
                "True model parameters {}, Prunable params {} ".format(
                    true_params, prunable_params))
        except AttributeError:
            raise
            # NOTE: the fallback below is unreachable while the raise above is active
            true_params = prunable_params = count_parameters(self.bare_model)
            shared_globals.console.info(
                "WARNING:\n\nmodel doesn't support true/prunable: True {}, Prunable params {} \n\n"
                .format(true_params, prunable_params))
        if self.config.prune_percentage == -1:  # -1 means auto
            must_prune_params = true_params - self.config.prune_percentage_target_params
            self.real_prune_percentage = must_prune_params / prunable_params
            if self.real_prune_percentage >= 0.9999:
                raise RuntimeError(
                    "real_prune_percentage {} >= ~ 1.".format(
                        self.real_prune_percentage))
            if self.real_prune_percentage >= 0.9:
                print("\n\nWarning: very high real_prune_percentage\n\n",
                      self.real_prune_percentage)
            if self.real_prune_percentage < 0:
                raise RuntimeError("real_prune_percentage {} <0.".format(
                    self.real_prune_percentage))
                # NOTE: the fallback below is unreachable while the raise above is active
                print("\nWARNING: real_prune_percentage<0: ",
                      self.real_prune_percentage, " setting to 0.1\n")
                self.real_prune_percentage = 0.1
        else:
            self.real_prune_percentage = self.config.prune_percentage
        print("current pruning percentage=", self.real_prune_percentage)

    shared_globals.console.info(
        "\n\nTrainable model parameters {}, non-trainable {} \n\n".format(
            count_parameters(self.bare_model),
            count_parameters(self.bare_model, False)))

    # DataParallel mode
    if not config.parallel_mode:
        self.model = self.bare_model
    elif config.parallel_mode == "distributed":
        torch.distributed.init_process_group(
            backend='nccl', world_size=1, rank=0,
            init_method='file://' + config.out_dir + "/shared_file")
        self.model = torch.nn.parallel.DistributedDataParallel(self.bare_model)
    else:
        self.model = torch.nn.DataParallel(self.bare_model)
        # self.model.cuda()

    # if load_model
    if config.get('load_model'):
        load_model_path = config.get('load_model')
        load_model_path = os.path.expanduser(load_model_path)
        shared_globals.console.info("Loading model located at: " + load_model_path)
        checkpoint = torch.load(load_model_path)
        self.model.load_state_dict(checkpoint['state_dict'])
        if self.use_swa:
            swa_state_dict = checkpoint.get('swa_state_dict', None)
            self.swa_n = checkpoint.get('swa_n', 0)
            if (swa_state_dict is not None) and not self.config.swa_model_load_same:
                self.swa_model.load_state_dict(swa_state_dict)
            else:
                shared_globals.console.warning(
                    "No swa_state_dict loaded! same loaded")
                self.swa_model.load_state_dict(checkpoint['state_dict'])
                self.swa_n = 0

    shared_globals.logger.info(str(self.model))
    shared_globals.current_learning_rate = config.optim_config['base_lr']
    self.optimizer, self.scheduler = create_optimizer(
        self.model.parameters(), config.optim_config)
    print("optimizer:", self.optimizer)

    loss_criterion_args = dict(config.loss_criterion_args)
    self.criterion = get_criterion(config.loss_criterion)(**loss_criterion_args)

    # init state
    inf_value = -float("inf")
    if self.config["optim_config"].get("model_selection", {}).get("select_min", False):
        inf_value = float("inf")
    self.state = {
        # 'config': self.config,
        'state_dict': None,
        'optimizer': None,
        'epoch': 0,
        'metrics': {},
        'best_metric_value': inf_value,
        'best_epoch': 0,
    }
    self.first_batch_done = False

    # init dataset loaders
    self.init_loaders()

    if config.get('load_model'):
        if not config.get("load_model_no_test_first"):
            testing_result = {}
            for name in self.config.datasets:
                dataset_config = AttrDefault(lambda: None, self.config.datasets[name])
                if dataset_config.testing:
                    testing_result[name] = self.test(0, name, dataset_config)
            # updating the state with the new results
            self.update_state(testing_result, 0)