# -*- coding: utf-8 -*-
# @Time : 2020/10/20 11:03 PM
# @Author : lishouxian
# @Email : [email protected]
# @File : main.py
# @Software: PyCharm
"""Command-line entry point: dispatch on ``config.mode`` to train or query the classifier."""
from engines.data import DataManager
from engines.utils.logger import get_logger
from engines.train import train
from engines.predict import Predictor
from engines.utils.word2vec import Word2VecUtils
# NOTE(review): Word2VecUtils and word2vec_config are imported but no branch
# below uses them -- confirm whether a word2vec mode is missing here.
from config import mode, classifier_config, word2vec_config
import json
import sys


def _prepare_data_manager(logger, mode_label):
    """Log the classifier configuration and selected mode, then build a DataManager.

    Shared set-up for both classifier modes. The log/construct order is the
    same as the original per-branch code: config dump, DataManager
    construction, ``mode: ...`` line, ``model: ...`` line.

    Args:
        logger: logger returned by ``get_logger``.
        mode_label: text written after ``mode: `` in the log.

    Returns:
        The ``DataManager`` constructed with ``logger``.
    """
    logger.info(json.dumps(classifier_config, indent=2))
    data_manage = DataManager(logger)
    logger.info('mode: {}'.format(mode_label))
    logger.info('model: {}'.format(classifier_config['classifier']))
    return data_manage


if __name__ == '__main__':
    logger = get_logger('./logs')
    if mode == 'train_classifier':
        # Train the classifier.
        data_manage = _prepare_data_manager(logger, 'train_classifier')
        train(data_manage, logger)
    elif mode == 'interactive_predict':
        # Classification test: build a predictor and run one prediction.
        data_manage = _prepare_data_manager(logger, 'predict_one')
        predictor = Predictor(data_manage, logger)
        # One throwaway prediction; the 'warm start' input suggests this warms
        # the model up before real use -- TODO confirm.
        predictor.predict_one('warm start')
    else:
        # Fail loudly on an unrecognised mode instead of exiting successfully
        # without doing anything (e.g. on a typo in config.mode).
        logger.error('unsupported mode: {}'.format(mode))
        sys.exit(1)
# NOTE(review): this line is a script collapsed onto one physical line, and its
# beginning is missing from this file. It opens partway through what looks like
# a directory-check routine over a `configures` object; no enclosing `def`
# header, `configures` binding or `vocabs_dir` binding is visible. It is kept
# byte-identical here. Restore the original multi-line layout before running it.
# What the visible code does:
#   - creates configures.vocabs_dir if `configures` has that attribute,
#     otherwise datasets_fold + '/vocabs' (presumably only when the vocabs
#     directory was found missing by the unseen preceding check);
#   - creates configures.log_dir when that path does not exist;
#   - under __main__:
#       - parses --config_file (default 'system.config') and builds a Configure;
#       - calls fold_check(configs), presumably the routine whose tail opens
#         this line;
#       - calls get_logger(configs.log_dir), configs.show_data_summary(logger)
#         and set_env(configs);
#       - builds a DataManager and calls train(configs, dataManager, logger).
# NOTE(review): the log_dir fallback branch runs
#   os.mkdir(configures.datasets_fold + '/vocabs'). That creates the vocabs
#   directory, not a log directory -- looks like a copy-paste slip; confirm.
# NOTE(review): hasattr(configures, log_dir) is evaluated only after
#   configures.log_dir was already read by os.path.exists(...). For a plain
#   attribute, a missing value would raise AttributeError before that fallback
#   is reached.
# NOTE(review): `mode` is computed from configs.mode but never used;
#   train() is called unconditionally.
print('vocabs fold not found, creating...') if hasattr(configures, vocabs_dir): os.mkdir(configures.vocabs_dir) else: os.mkdir(configures.datasets_fold + '/vocabs') log_dir = 'log_dir' if not os.path.exists(configures.log_dir): print('log fold not found, creating...') if hasattr(configures, log_dir): os.mkdir(configures.log_dir) else: os.mkdir(configures.datasets_fold + '/vocabs') if __name__ == '__main__': parser = argparse.ArgumentParser(description='Bert_nlu_joint') parser.add_argument('--config_file', default='system.config', help='Configuration File') args = parser.parse_args() configs = Configure(config_file=args.config_file) fold_check(configs) logger = get_logger(configs.log_dir) configs.show_data_summary(logger) set_env(configs) mode = configs.mode.lower() dataManager = DataManager(configs, logger) train(configs, dataManager, logger)
# NOTE(review): this line is a script collapsed onto one physical line, and its
# beginning is missing from this file. It opens partway through what looks like
# a directory-check routine over a `configures` object; no enclosing `def`
# header or `configures` binding is visible. It is kept byte-identical here.
# Restore the original multi-line layout before running it.
# What the visible code does:
#   - calls exit(1) if configures.datasets_fold does not exist;
#   - if checkpoints_dir is missing: creates configures.checkpoints_dir when it
#     is a two-part 'parent/child' path whose parent exists, otherwise runs
#     os.mkdir('checkpoints');
#   - creates configures.log_dir when that path does not exist;
#   - under __main__:
#       - calls get_logger('logs');
#       - picks torch.device 'cuda:0' when torch.cuda.is_available(),
#         else 'cpu';
#       - calls train(device, logger).
# NOTE(review): the checkpoints fallback os.mkdir('checkpoints') raises
#   FileExistsError if ./checkpoints already exists. Example: on a second run
#   with a configured checkpoints_dir nested more than two levels deep.
# NOTE(review): hasattr(configures, checkpoints_dir) and
#   hasattr(configures, log_dir) run only after the same attribute was already
#   read inside os.path.exists(...). For a plain attribute, a missing value
#   would raise AttributeError before these checks matter.
# NOTE(review): the log_dir fallback creates datasets_fold + '/vocabs', not a
#   log directory -- looks like a copy-paste slip; confirm.
# NOTE(review): the __main__ block hard-codes mode = 'train', so its mode check
#   is always true. It also never invokes the directory checks above -- confirm
#   that is intended.
if not os.path.exists(configures.datasets_fold): print('datasets fold not found') exit(1) checkpoints_dir = 'checkpoints_dir' if not os.path.exists(configures.checkpoints_dir) or not hasattr( configures, checkpoints_dir): print('checkpoints fold not found, creating...') paths = configures.checkpoints_dir.split('/') if len(paths) == 2 and os.path.exists( paths[0]) and not os.path.exists(configures.checkpoints_dir): os.mkdir(configures.checkpoints_dir) else: os.mkdir('checkpoints') log_dir = 'log_dir' if not os.path.exists(configures.log_dir): print('log fold not found, creating...') if hasattr(configures, log_dir): os.mkdir(configures.log_dir) else: os.mkdir(configures.datasets_fold + '/vocabs') if __name__ == '__main__': logger = get_logger('logs') device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') mode = 'train' if mode == 'train': train(device, logger)