def buildStatManager(is_train, path_manager: PathManager, task_config: config.TaskConfig = None,
                     train_config: config.TrainingConfig = None, test_config: config.TestConfig = None):
    """Create the statistics manager for a training or a testing run.

    Args:
        is_train: when truthy, a ``TrainStatManager`` is built; otherwise a
            ``TestStatManager``.
        path_manager: supplies the stat-file location (``trainStat()`` or
            ``testStat()``) and, for training, the model save path (``model()``).
        task_config: not read by this function.
        train_config: training settings; ``config.train`` is used when None.
        test_config: testing settings; ``config.test`` is used when None.

    Returns:
        The constructed ``TrainStatManager`` or ``TestStatManager``.
    """
    if not is_train:
        test_cfg = config.test if test_config is None else test_config
        return TestStatManager(
            stat_save_path=path_manager.testStat(),
            test_report_iter=test_cfg.ReportIter,
            total_iter=test_cfg.Epoch,
            metric_num=len(test_cfg.Metrics),
            metric_names=test_cfg.Metrics,
            verbose=test_cfg.Verbose,
        )

    train_cfg = config.train if train_config is None else train_config
    return TrainStatManager(
        stat_save_path=path_manager.trainStat(),
        model_save_path=path_manager.model(),
        save_latest_model=train_cfg.SaveLatest,
        train_report_iter=train_cfg.ValCycle,
        val_report_iter=train_cfg.ValEpisode,
        total_iter=train_cfg.TrainEpoch,
        metric_num=len(train_cfg.Metrics),
        criteria=train_cfg.Criteria,
        # The default metric index is 0 (presumably the first entry of Metrics — verify).
        criteria_metric_index=0,
        metric_names=train_cfg.Metrics,
        verbose=train_cfg.Verbose,
    )
def buildModel(
        path_manager: PathManager,
        task_config=None,
        model_params: config.ParamsConfig = None,
        loss_func=None,
        data_source=None,
):
    """Instantiate the configured model on CUDA and optionally preload saved weights.

    Args:
        path_manager: passed through to the model constructor.
        task_config: task configuration; ``config.task`` is used when None.
            Its ``PreloadStateDictVersions`` lists the saved versions whose
            state dicts are loaded into the new model, and its ``Dataset`` is
            used to locate those files.
        model_params: model parameters; ``config.params`` is used when None.
            ``ModelName`` selects the implementation from ``ModelSwitch``.
        loss_func: passed through to the model constructor.
        data_source: passed through to the model constructor.

    Returns:
        The constructed model after ``.cuda()`` has been applied.

    Raises:
        ValueError: if ``model_params.ModelName`` has no entry in ``ModelSwitch``.
    """
    if model_params is None:
        model_params = config.params
    if task_config is None:
        task_config = config.task

    # Guard only the registry lookup: a KeyError raised inside the model's own
    # constructor must propagate unchanged instead of being misreported as an
    # unknown model name.
    try:
        model_cls = ModelSwitch[model_params.ModelName]
    except KeyError as err:
        raise ValueError(
            "[ModelBuilder] No matched model implementation for '%s'" % model_params.ModelName) from err

    model = model_cls(path_manager, model_params, task_config, loss_func, data_source).cuda()

    # Assemble pretrained parameters from previously saved versions.
    if len(task_config.PreloadStateDictVersions) > 0:
        _preloadStateDicts(model, task_config, model_params)
    return model


def _preloadStateDicts(model, task_config, model_params):
    """Load each state dict listed in ``task_config.PreloadStateDictVersions`` into ``model``.

    Each version is loaded with ``strict=False`` in list order, so a later
    version overwrites entries an earlier one also provided. Afterwards this
    prints the model parameters that no version supplied and the checkpoint
    keys the model reported as unexpected.
    """
    # An insertion-ordered dict keeps named_parameters() order for the report
    # while making each removal O(1) (the previous list made this O(n^2)).
    unloaded = dict.fromkeys(name for name, _ in model.named_parameters())
    unexpected_keys = []
    for version in task_config.PreloadStateDictVersions:
        pm = PathManager(dataset=task_config.Dataset, version=version, model_name=model_params.ModelName)
        state_dict = torch.load(pm.model())
        load_result = model.load_state_dict(state_dict, strict=False)
        rejected = set(load_result.unexpected_keys)
        for key in state_dict.keys():
            # Keys that are not parameters (e.g. buffers) are simply absent from
            # `unloaded`; pop's default makes that a no-op.
            if key not in rejected:
                unloaded.pop(key, None)
        unexpected_keys.extend(load_result.unexpected_keys)
    if unloaded:
        print('[buildModel] Preloading, unloaded keys:')
        pprint(list(unloaded))
    if unexpected_keys:
        print('[buildModel] Preloading, unexpected keys:')
        pprint(unexpected_keys)