def main_train():
    """Train the simple MNIST model.

    Parses the command-line arguments, loads the config file they point to,
    then builds the data loader, the model and the trainer and runs training.

    If the configuration cannot be loaded, prints the error (plus the
    argument parser's help, when available) and terminates the process with
    exit status 1.
    """
    import sys

    print('[INFO] 解析配置...')
    parser = None
    config = None
    try:
        args, parser = get_train_args()
        config = process_config(args.config)
    except Exception as e:
        print('[Exception] 配置无效, %s' % e)
        if parser:
            parser.print_help()
        print('[Exception] 参考: python main_train.py -c configs/simple_mnist_config.json')
        # An unusable config is a failure: report it with a non-zero status
        # (the original exit(0) told shells/CI the run succeeded).
        sys.exit(1)

    print('[INFO] 加载数据...')
    dl = SimpleMnistDL(config=config)

    print('[INFO] 构造网络...')
    model = SimpleMnistModel(config=config)

    print('[INFO] 训练网络...')
    trainer = SimpleMnistTrainer(
        model=model.model,
        data=[dl.get_train_data(), dl.get_test_data()],
        config=config)
    trainer.train()
    print('[INFO] 训练完成...')
def train_vgg_manga():
    """Train ``MangaFaceNetModel`` on the ``FaceNetDL`` train/validation data.

    Parses the command-line arguments (config file and optional pretrained
    model path ``pre_train``), seeds NumPy's global RNG, loads the data,
    builds the model (optionally from pretrained weights) and trains it with
    ``VGGMangaTrainer``. The chosen backbone (``config.backbone``) is only
    reported here; the model itself reads it from ``config``.

    If the configuration cannot be loaded, prints the error (plus the
    argument parser's help, when available) and terminates the process with
    exit status 1.

    NOTE(review): this file defines ``train_vgg_manga`` a second time further
    down; at import time that later definition replaces this one. Confirm
    which variant is meant to be the entry point.
    """
    import sys

    print('[INFO] 解析配置…')
    parser = None
    config = None
    model_path = None
    try:
        args, parser = get_train_args()
        config = process_config(args.config)
        model_path = args.pre_train
    except Exception as e:
        print('[Exception] 配置无效, %s' % e)
        if parser:
            parser.print_help()
        print('[Exception] 参考: python main_train.py -c configs/simple_mnist_config.json')
        # An unusable config is a failure: exit with a non-zero status.
        sys.exit(1)

    # Fixed seed for NumPy's global RNG so runs are reproducible
    # (presumably consumed by the data loader's shuffling — confirm).
    np.random.seed(47)

    print('[INFO] 加载数据…')
    dl = FaceNetDL(config=config)

    print('[INFO] 构造网络…')
    if config.backbone == 'vgg':
        print('[INFO] 使用 VGG 作为骨架')
    elif config.backbone == 'alexnet':
        print('[INFO] 使用 AlexNet 作为骨架')
    else:
        print('[INFO] 使用多层 CNN 作为骨架')

    # The original only recognised the literal string 'None' as "no pretrained
    # weights" (presumably the CLI default); also accept a real None or ''.
    if model_path in (None, '', 'None'):
        model = MangaFaceNetModel(config=config)
    else:
        model = MangaFaceNetModel(config=config, model_path=model_path)

    print('[INFO] 训练网络')
    trainer = VGGMangaTrainer(
        model=model.model,
        data=[dl.get_train_data(), dl.get_validation_data()],
        config=config)
    trainer.train()
    print('[INFO] 训练完成…')
def train_vgg_manga(manga_dir='manga109_frame_face'):
    """Train ``VGGMangaSimpleModel`` on the ``VGGMangaDL`` train/validation data.

    Parses the command-line arguments (config file and optional pretrained
    model path ``pre_train``), seeds NumPy's global RNG, loads the data from
    ``manga_dir``, builds the model (optionally from pretrained weights) and
    trains it with ``VGGMangaTrainer``.

    If the configuration cannot be loaded, prints the error (plus the
    argument parser's help, when available) and terminates the process with
    exit status 1.

    :param manga_dir: dataset directory name handed to ``VGGMangaDL``;
        defaults to ``'manga109_frame_face'`` (the previously hard-coded value).
    """
    import sys

    print('[INFO] 解析配置…')
    parser = None
    config = None
    model_path = None
    try:
        args, parser = get_train_args()
        config = process_config(args.config)
        model_path = args.pre_train
    except Exception as e:
        print('[Exception] 配置无效, %s' % e)
        if parser:
            parser.print_help()
        print('[Exception] 参考: python main_train.py -c configs/simple_mnist_config.json')
        # An unusable config is a failure: exit with a non-zero status.
        sys.exit(1)

    # Fixed seed for NumPy's global RNG so runs are reproducible
    # (presumably consumed by the data loader's shuffling — confirm).
    np.random.seed(47)

    print('[INFO] 加载数据…')
    dl = VGGMangaDL(config=config, manga_dir=manga_dir)

    print('[INFO] 构造网络…')
    # The original only recognised the literal string 'None' as "no pretrained
    # weights" (presumably the CLI default); also accept a real None or ''.
    if model_path in (None, '', 'None'):
        model = VGGMangaSimpleModel(config=config)
    else:
        model = VGGMangaSimpleModel(config=config, model_path=model_path)

    print('[INFO] 训练网络')
    trainer = VGGMangaTrainer(
        model=model.model,
        data=[dl.get_train_data(), dl.get_validation_data()],
        config=config)
    trainer.train()
    print('[INFO] 训练完成…')
def train_main():
    """Train a super-resolution network described by a JSON config file.

    Workflow:
      1. Parse the command-line arguments, load the config and copy the
         config file into ``experiments/<exp_name>``.
      2. Build ``ImageLoader`` from ``config['train_data_loader']`` and the
         network class ``Net`` from module ``models.<config['trainer']['net']>``.
         The network is wrapped in ``torch.nn.DataParallel`` when more than one
         CUDA device is visible.
      3. Train with ``SRTrainer`` on the HDF5 sample data.
      4. Write the best score to ``experiments/<exp_name>/performance.txt``.
         Re-dump the config into the same directory with
         ``config['trainer']['checkpoint']`` set to the best model.

    If the configuration cannot be loaded or copied, prints the error (plus
    the argument parser's help, when available) and terminates the process
    with exit status 1.

    :raises RuntimeWarning: if module ``models.<net>`` cannot be imported
        (the original ``ModuleNotFoundError`` is chained as the cause).
    """
    import sys

    print('[INFO] Retrieving configuration...')
    parser = None
    args = None
    config = None
    try:
        args, parser = get_train_args()
        config = process_config(args.config)
        exp_dir = os.path.join("experiments", config['exp_name'])
        # shutil.copy2 treats a non-existent destination as a *file* name, so
        # make sure the experiment directory exists before copying into it.
        os.makedirs(exp_dir, exist_ok=True)
        shutil.copy2(args.config, exp_dir)
    except Exception as e:
        print('[Exception] Configuration is invalid, %s' % e)
        if parser:
            parser.print_help()
        print('[Exception] Refer to: python main_train.py -c configs/rrgun.json')
        # An unusable config is a failure: exit with a non-zero status.
        sys.exit(1)

    print('[INFO] Loading data...')
    # Let cuDNN benchmark and pick the fastest convolution algorithms.
    torch.backends.cudnn.benchmark = True
    dl = ImageLoader(config=config['train_data_loader'])

    print('[INFO] Building graph...')
    # Only the dynamic import is guarded: errors raised while constructing the
    # model should surface as-is, not be reported as a wrong model name.
    try:
        Net = importlib.import_module(
            'models.{}'.format(config['trainer']['net'])).Net
    except ModuleNotFoundError as err:
        raise RuntimeWarning(
            "The model name is incorrect or does not exist! Please check!") from err
    model = Net(config=config['model'])
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    print_network(model)

    print('[INFO] Training the graph...')
    trainer = SRTrainer(
        model=model,
        data={
            'train': dl.get_hdf5_sample_data(),
            'test': dl.get_test_data(),
        },
        config=config['trainer'])
    highest_score, best_model = trainer.train()

    with open(os.path.join(exp_dir, 'performance.txt'), 'w') as f:
        f.write(str(highest_score))

    # Same path as the copy made at startup, so this overwrites it with the
    # config pointing at the best checkpoint.
    json_file = os.path.join(exp_dir, os.path.basename(args.config))
    config['trainer']['checkpoint'] = best_model
    with open(json_file, 'w') as file_out:
        json.dump(config, file_out, indent=2)
    print('[INFO] Training is completed.')
def train_main():
    """Train a super-resolution network (WMCNN HDF5 data) from a JSON config.

    Workflow:
      1. Parse the command-line arguments, load the config and copy the
         config file into ``experiments/<exp_name>``.
      2. Build ``ImageLoader`` from ``config['train_data_loader']`` and the
         network class ``Net`` from module ``models.<config['trainer']['net']>``.
      3. Train with ``SRTrainer`` on ``dl.get_wmcnn_hdf5_data()``.
      4. Write the best score to ``experiments/<exp_name>/performance.txt``.
         Re-dump the config into the same directory with
         ``config['trainer']['checkpoint']`` set to the best model.

    If the configuration cannot be loaded or copied, prints the error (plus
    the argument parser's help, when available) and terminates the process
    with exit status 1.

    NOTE(review): this file defines ``train_main`` earlier as well; this later
    definition replaces that one at import time.

    :raises RuntimeWarning: if module ``models.<net>`` cannot be imported
        (the original ``ModuleNotFoundError`` is chained as the cause).
    """
    import sys

    print('[INFO] Retrieving configuration...')
    parser = None
    args = None
    config = None
    try:
        args, parser = get_train_args()
        config = process_config(args.config)
        exp_dir = os.path.join("experiments", config['exp_name'])
        # shutil.copy2 treats a non-existent destination as a *file* name, so
        # make sure the experiment directory exists before copying into it.
        os.makedirs(exp_dir, exist_ok=True)
        shutil.copy2(args.config, exp_dir)
    except Exception as e:
        print('[Exception] Configuration is invalid, %s' % e)
        if parser:
            parser.print_help()
        print('[Exception] Refer to: python main_train.py -c configs/wmcnn.json')
        # An unusable config is a failure: exit with a non-zero status.
        sys.exit(1)

    print('[INFO] Loading data...')
    dl = ImageLoader(config=config['train_data_loader'])

    print('[INFO] Building graph...')
    # Only the dynamic import is guarded: errors raised while constructing the
    # model should surface as-is, not be reported as a wrong model name.
    try:
        Net = importlib.import_module(
            'models.{}'.format(config['trainer']['net'])).Net
    except ModuleNotFoundError as err:
        raise RuntimeWarning(
            "The model name is incorrect or does not exist! Please check!") from err
    model = Net(config=config['model'])
    print_network(model)

    print('[INFO] Training the graph...')
    # NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches
    # synchronous. This is a debugging aid that slows training, and it may have
    # no effect if CUDA was already initialised above. Confirm it is still
    # wanted.
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    trainer = SRTrainer(
        model=model,
        data={
            'train': dl.get_wmcnn_hdf5_data(),
            'test': dl.get_test_data(),
        },
        config=config['trainer'])
    highest_score, best_model = trainer.train()

    with open(os.path.join(exp_dir, 'performance.txt'), 'w') as f:
        f.write(str(highest_score))

    # Same path as the copy made at startup, so this overwrites it with the
    # config pointing at the best checkpoint.
    json_file = os.path.join(exp_dir, os.path.basename(args.config))
    config['trainer']['checkpoint'] = best_model
    with open(json_file, 'w') as file_out:
        json.dump(config, file_out, indent=2)
    print('[INFO] Training is completed.')