コード例 #1
0
def main_train():
    """
    Train the model.

    Parses the training arguments, loads the configuration named by
    ``args.config``, then builds a ``SimpleMnistDL`` data loader, a
    ``SimpleMnistModel`` and a ``SimpleMnistTrainer`` and runs
    ``trainer.train()``.

    If obtaining the arguments or processing the configuration raises an
    ``Exception``, the error, the parser help (when a parser was obtained)
    and a usage hint are printed and the process exits with status 1.

    :return: None
    """
    # Local import: only the error path below needs it.
    import sys

    # Single-argument print(...) calls behave identically on Python 2 and 3.
    print('[INFO] 解析配置...')

    parser = None
    config = None

    try:
        args, parser = get_train_args()
        config = process_config(args.config)
    except Exception as e:
        print('[Exception] 配置无效, %s' % e)
        if parser:
            parser.print_help()
        print('[Exception] 参考: python main_train.py -c configs/simple_mnist_config.json')
        # A configuration failure must not report success to the caller,
        # so exit with a non-zero status.
        sys.exit(1)

    print('[INFO] 加载数据...')
    dl = SimpleMnistDL(config=config)

    print('[INFO] 构造网络...')
    model = SimpleMnistModel(config=config)

    print('[INFO] 训练网络...')
    # The trainer receives the wrapped model object and the
    # [train, test] data pair produced by the loader.
    trainer = SimpleMnistTrainer(
        model=model.model,
        data=[dl.get_train_data(), dl.get_test_data()],
        config=config)
    trainer.train()
    print('[INFO] 训练完成...')
コード例 #2
0
def train_vgg_manga():
    """
    Train a ``MangaFaceNetModel`` using the command-line configuration.

    Parses the training arguments, loads the configuration named by
    ``args.config`` and reads the optional pre-trained model path from
    ``args.pre_train``. It then seeds NumPy's random generator, builds a
    ``FaceNetDL`` data loader and the model, and runs ``VGGMangaTrainer``
    on the loader's train and validation data.

    If obtaining the arguments or processing the configuration raises an
    ``Exception``, the error, the parser help (when a parser was obtained)
    and a usage hint are printed and the process exits with status 1.

    :return: None
    """
    # Local import: only the error path below needs it.
    import sys

    print('[INFO] 解析配置…')
    parser = None
    config = None
    model_path = None

    try:
        args, parser = get_train_args()
        config = process_config(args.config)
        model_path = args.pre_train
    except Exception as e:
        print('[Exception] 配置无效, %s' % e)
        if parser:
            parser.print_help()
        print(
            '[Exception] 参考: python main_train.py -c configs/simple_mnist_config.json'
        )
        # A configuration failure must not report success to the caller,
        # so exit with a non-zero status.
        sys.exit(1)

    # Fixed seed so NumPy-based randomness is repeatable between runs.
    np.random.seed(47)

    print('[INFO] 加载数据…')
    dl = FaceNetDL(config=config)

    print('[INFO] 构造网络…')
    # Report which backbone the configuration selects; anything other than
    # 'vgg' or 'alexnet' is reported as the multi-layer CNN.
    if config.backbone == 'vgg':
        print('[INFO] 使用 VGG 作为骨架')
    elif config.backbone == 'alexnet':
        print('[INFO] 使用 AlexNet 作为骨架')
    else:
        print('[INFO] 使用多层 CNN 作为骨架')

    # NOTE(review): this compares against the *string* 'None', presumably
    # the command-line default of the pre_train option; a real ``None``
    # would take the first branch. Confirm against get_train_args().
    if model_path != 'None':
        model = MangaFaceNetModel(config=config, model_path=model_path)
    else:
        model = MangaFaceNetModel(config=config)

    print('[INFO] 训练网络')
    trainer = VGGMangaTrainer(
        model=model.model,
        data=[dl.get_train_data(),
              dl.get_validation_data()],
        config=config)
    trainer.train()
    print('[INFO] 训练完成…')
コード例 #3
0
def train_vgg_manga(manga_dir='manga109_frame_face'):
    """
    Train a ``VGGMangaSimpleModel`` using the command-line configuration.

    Parses the training arguments, loads the configuration named by
    ``args.config`` and reads the optional pre-trained model path from
    ``args.pre_train``. It then seeds NumPy's random generator, builds a
    ``VGGMangaDL`` data loader and the model, and runs ``VGGMangaTrainer``
    on the loader's train and validation data.

    If obtaining the arguments or processing the configuration raises an
    ``Exception``, the error, the parser help (when a parser was obtained)
    and a usage hint are printed and the process exits with status 1.

    :param manga_dir: value forwarded to ``VGGMangaDL`` as ``manga_dir``;
        defaults to the previously hard-coded ``'manga109_frame_face'``.
    :return: None
    """
    # Local import: only the error path below needs it.
    import sys

    print('[INFO] 解析配置…')
    parser = None
    config = None
    model_path = None

    try:
        args, parser = get_train_args()
        config = process_config(args.config)
        model_path = args.pre_train
    except Exception as e:
        print('[Exception] 配置无效, %s' % e)
        if parser:
            parser.print_help()
        print(
            '[Exception] 参考: python main_train.py -c configs/simple_mnist_config.json'
        )
        # A configuration failure must not report success to the caller,
        # so exit with a non-zero status.
        sys.exit(1)

    # Fixed seed so NumPy-based randomness is repeatable between runs.
    np.random.seed(47)

    print('[INFO] 加载数据…')
    dl = VGGMangaDL(config=config, manga_dir=manga_dir)

    print('[INFO] 构造网络…')
    # NOTE(review): this compares against the *string* 'None', presumably
    # the command-line default of the pre_train option; a real ``None``
    # would take the first branch. Confirm against get_train_args().
    if model_path != 'None':
        model = VGGMangaSimpleModel(config=config, model_path=model_path)
    else:
        model = VGGMangaSimpleModel(config=config)

    print('[INFO] 训练网络')
    trainer = VGGMangaTrainer(
        model=model.model,
        data=[dl.get_train_data(),
              dl.get_validation_data()],
        config=config)
    trainer.train()
    print('[INFO] 训练完成…')
コード例 #4
0
ファイル: main_train.py プロジェクト: boyuzz/SISR-KD-Ensemble
def train_main():
    """
    Train the model.

    Steps, in order:

    1. Parse the training arguments, load the configuration named by
       ``args.config`` and copy that file to ``experiments/<exp_name>``.
    2. Build an ``ImageLoader`` from ``config['train_data_loader']``.
    3. Import ``models.<config['trainer']['net']>`` and instantiate its
       ``Net`` class with ``config['model']``; wrap it in
       ``torch.nn.DataParallel`` when more than one CUDA device is visible.
    4. Run ``SRTrainer.train()`` and store the returned score in
       ``performance.txt`` and the returned best model in the experiment's
       copy of the configuration (as ``config['trainer']['checkpoint']``).

    If step 1 raises an ``Exception``, the error, the parser help (when a
    parser was obtained) and a usage hint are printed and the process exits
    with status 1.

    :return: None
    :raises RuntimeWarning: if a ``ModuleNotFoundError`` occurs while
        importing or constructing the network.
    """
    # Local import: only the error path below needs it.
    import sys

    print('[INFO] Retrieving configuration...')
    parser = None
    args = None
    config = None
    exp_dir = None
    # TODO: modify the path of best checkpoint after training
    try:
        args, parser = get_train_args()
        # args.config = 'experiments/stacksr lr=1e-3 28init 2x/stacksr.json'
        # args.config = 'configs/lapsrn.json'
        config = process_config(args.config)
        # Every artefact of this run lives under experiments/<exp_name>.
        exp_dir = os.path.join("experiments", config['exp_name'])
        shutil.copy2(args.config, exp_dir)
    except Exception as e:
        print('[Exception] Configuration is invalid, %s' % e)
        if parser:
            parser.print_help()
        print(
            '[Exception] Refer to: python main_train.py -c configs/rrgun.json')
        # A configuration failure must not report success to the caller,
        # so exit with a non-zero status.
        sys.exit(1)
    # config = process_config('configs/train_textcnn.json')
    # np.random.seed(config.seed)  # fix the random seed

    print('[INFO] Loading data...')
    torch.backends.cudnn.benchmark = True
    dl = ImageLoader(config=config['train_data_loader'])

    print('[INFO] Building graph...')
    try:
        # The network class is looked up by module name from the config.
        Net = importlib.import_module('models.{}'.format(
            config['trainer']['net'])).Net
        model = Net(config=config['model'])
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
        print_network(model)
    except ModuleNotFoundError as err:
        # Chain the original error so the missing module name stays visible.
        raise RuntimeWarning(
            "The model name is incorrect or does not exist! Please check!"
        ) from err

    print('[INFO] Training the graph...')
    trainer = SRTrainer(
        model=model,
        data={
            'train': dl.get_hdf5_sample_data(),
            'test': dl.get_test_data()
        },
        # data={'train': dl.get_hdf5_data(), 'test': dl.get_test_data()},
        config=config['trainer'])

    highest_score, best_model = trainer.train()
    with open(os.path.join(exp_dir, 'performance.txt'), 'w') as f:
        f.write(str(highest_score))

    # Record the best checkpoint in the experiment's copy of the config.
    # Serialise before opening the file so that a serialisation error does
    # not leave a truncated config behind.
    config['trainer']['checkpoint'] = best_model
    serialized = json.dumps(config, indent=2)
    json_file = os.path.join(exp_dir, os.path.basename(args.config))
    with open(json_file, 'w') as file_out:
        file_out.write(serialized)

    print('[INFO] Training is completed.')
コード例 #5
0
def train_main():
    """
    Train the model.

    Steps, in order:

    1. Parse the training arguments, load the configuration named by
       ``args.config`` and copy that file to ``experiments/<exp_name>``.
    2. Build an ``ImageLoader`` from ``config['train_data_loader']``.
    3. Import ``models.<config['trainer']['net']>`` and instantiate its
       ``Net`` class with ``config['model']``.
    4. Run ``SRTrainer.train()`` and store the returned score in
       ``performance.txt`` and the returned best model in the experiment's
       copy of the configuration (as ``config['trainer']['checkpoint']``).

    If step 1 raises an ``Exception``, the error, the parser help (when a
    parser was obtained) and a usage hint are printed and the process exits
    with status 1.

    :return: None
    :raises RuntimeWarning: if a ``ModuleNotFoundError`` occurs while
        importing or constructing the network.
    """
    # Local import: only the error path below needs it.
    import sys

    print('[INFO] Retrieving configuration...')
    parser = None
    args = None
    config = None
    exp_dir = None
    # TODO: modify the path of best checkpoint after training
    try:
        args, parser = get_train_args()
        # args.config = 'experiments/stacksr lr=1e-3 28init 3x/stacksr.json'
        # args.config = 'configs/lapsrn.json'
        config = process_config(args.config)
        # Every artefact of this run lives under experiments/<exp_name>.
        exp_dir = os.path.join("experiments", config['exp_name'])
        shutil.copy2(args.config, exp_dir)
    except Exception as e:
        print('[Exception] Configuration is invalid, %s' % e)
        if parser:
            parser.print_help()
        print('[Exception] Refer to: python main_train.py -c configs/wmcnn.json')
        # A configuration failure must not report success to the caller,
        # so exit with a non-zero status.
        sys.exit(1)
    # config = process_config('configs/train_textcnn.json')
    # np.random.seed(config.seed)  # fix the random seed

    print('[INFO] Loading data...')
    dl = ImageLoader(config=config['train_data_loader'])

    print('[INFO] Building graph...')
    try:
        # The network class is looked up by module name from the config.
        net_module = importlib.import_module(
            'models.{}'.format(config['trainer']['net']))
        Net = net_module.Net
        model = Net(config=config['model'])
        print_network(model)
    except ModuleNotFoundError as err:
        # Chain the original error so the missing module name stays visible.
        raise RuntimeWarning(
            "The model name is incorrect or does not exist! Please check!"
        ) from err

    # if config['distributed']:
    #     os.environ['MASTER_ADDR'] = '127.0.0.1'
    #     os.environ['MASTER_PORT'] = '29500'
    #     torch.distributed.init_process_group(backend='nccl', world_size=4, rank=2)

    print('[INFO] Training the graph...')
    # trainer = SRTrainer(
    #     model=model,
    #     data={'train': dl.get_train_data(), 'test': dl.get_test_data()},
    #     config=config['trainer'])
    # NOTE(review): CUDA_LAUNCH_BLOCKING=1 is normally a debugging setting
    # that makes CUDA kernel launches synchronous - confirm it is meant to
    # stay enabled for regular training runs.
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    trainer = SRTrainer(
        model=model,
        data={'train': dl.get_wmcnn_hdf5_data(), 'test': dl.get_test_data()},
        # data={'train': dl.get_hdf5_data(), 'test': dl.get_test_data()},
        config=config['trainer'])

    highest_score, best_model = trainer.train()
    with open(os.path.join(exp_dir, 'performance.txt'), 'w') as f:
        f.write(str(highest_score))

    # Record the best checkpoint in the experiment's copy of the config.
    # Serialise before opening the file so that a serialisation error does
    # not leave a truncated config behind.
    config['trainer']['checkpoint'] = best_model
    serialized = json.dumps(config, indent=2)
    json_file = os.path.join(exp_dir, os.path.basename(args.config))
    with open(json_file, 'w') as file_out:
        file_out.write(serialized)

    print('[INFO] Training is completed.')