def training_pipeline():
    """Run ``NUM_TRAINS`` rounds of self-play training with champion gating.

    Model number 0 is created first and its weights are saved. Each round then:

    1. calls ``self_play`` with the current champion's model number;
    2. builds a contender ``CNNModel`` numbered 1, 2, 3, ... and trains it on
       that round's self-play data;
    3. calls ``bot_fight`` with the champion and contender model numbers;
    4. promotes the contender to champion when the returned win count is at
       least ``np.ceil(BOT_GAMES * 0.55)``.

    At the end of every round the result is logged, and the champion's
    weights are saved with ``best=True``, whether or not it changed.
    """
    champion_num = 0
    champion = CNNModel(champion_num)
    champion.save_weights()

    # Contender numbers start at 1 because model 0 is the seed champion.
    for round_num in range(1, NUM_TRAINS + 1):
        states, valids, improved_policy, win_loss = self_play(champion_num)

        contender = CNNModel(round_num)
        state_arr = np.array(states, np.uint32)
        valid_arr = np.array(valids, np.float32)
        outcome_arr = np.array(win_loss)
        policy_arr = np.array(improved_policy)
        contender.train_model(state_arr, valid_arr, outcome_arr, policy_arr)

        wins = bot_fight(champion.model_num, contender.model_num)

        # Gate promotion on a 55% win threshold over BOT_GAMES matches.
        if wins >= np.ceil(BOT_GAMES * 0.55):
            champion, champion_num = contender, contender.model_num

        logging.info(
            f'best model: {champion_num}, new model won {wins}')
        champion.save_weights(best=True)
Beispiel #2
0
def train_model(work_dir,
                pkl_data_dir,
                initial_epoch=0,
                batch_size=16,
                initial_lr=0.001,
                model_name='cnn',
                num_gpus=1):
    """Build a ``CNNModel``, optionally from a saved file, and train it.

    When ``initial_epoch`` is positive, the function globs for exactly one
    file matching ``work_dir + '<model_name>_e<NN>*.hd5'``. Its path is
    passed to ``CNNModel``, and ``initial_lr`` is scaled by
    ``config['train']['lr_decay'] ** (initial_epoch - 1)``.

    Args:
        work_dir: Prefix prepended verbatim to the glob pattern (no separator
            is inserted, so it presumably must end with a path separator;
            confirm against how model files are saved). Also forwarded to
            ``CNNModel.train_model``.
        pkl_data_dir: Forwarded to ``DataSet`` as ``pkl_file_dir``.
        initial_epoch: Epoch number embedded in the model file name to load.
            Values <= 0 mean no file is loaded.
        batch_size: Batch size for both the train and eval generators.
        initial_lr: Base learning rate before the decay adjustment.
        model_name: File-name prefix for the glob; also forwarded to training.
        num_gpus: Forwarded to ``CNNModel``.

    Raises:
        AssertionError: If ``initial_epoch > 0`` and the glob does not match
            exactly one file.
    """
    model_path = None
    if initial_epoch > 0:
        pattern = work_dir + '%s_e%02d*.hd5' % (model_name, initial_epoch)
        model_paths = glob(pattern)
        if len(model_paths) != 1:
            print('cannot found model save file!!!!')
            # Raised explicitly instead of `assert False`. A bare assert is
            # stripped under `python -O`, which would silently train a fresh
            # model with an already-decayed learning rate. AssertionError is
            # kept so callers see the same exception type as before.
            raise AssertionError(
                'expected exactly one model file matching %r, found %d'
                % (pattern, len(model_paths)))
        model_path = model_paths[0]
        initial_lr = initial_lr * (config['train']['lr_decay']
                                   **(initial_epoch - 1))
    else:
        initial_epoch = 0
    dataset = DataSet(data_dir=config['data']['data_dir'],
                      train=True,
                      pkl_file_dir=pkl_data_dir)
    cnn_model = CNNModel(model_path, True, initial_lr, num_gpus)

    train_data_gen = dataset.get_train_data_gen(batch_size=batch_size)
    train_data_num = dataset.get_train_data_num()
    eval_data_gen = dataset.get_eval_data_gen(batch_size=batch_size)
    eval_data_num = dataset.get_eval_data_num()
    # NOTE(review): epoch 0 is passed even when resuming from a saved file.
    # This may be deliberate, since initial_lr was already pre-decayed above
    # and forwarding the real epoch could decay it twice. Or it may be a bug,
    # since epoch numbering restarts. Confirm against CNNModel.train_model.
    cnn_model.train_model(train_data_gen,
                          train_data_num,
                          eval_data_gen,
                          eval_data_num,
                          batch_size,
                          work_dir,
                          model_name=model_name,
                          initial_epoch=0)