def test_model_trainer_loop(local_ray, tmpdir):
    """Run a trainer against a mocked replay buffer and check that the
    shared weights advance past revision 0 and training completes."""
    game = OthelloGame(6)
    network = OthelloNNet(game)
    storage = SharedStorage.remote(network.get_weights())
    assert ray.get(storage.get_revision.remote()) == 0

    buffer = MockedReplayBuffer.remote(games_to_play=4,
                                       games_to_use=4,
                                       folder=tmpdir)
    buffer.add_game_examples.remote(mock_example_data(game))

    trainer = ModelTrainer.options(num_gpus=0).remote(
        buffer, storage, game, network.__class__, dict(args),
        selfplay_training_ratio=1)
    ray.get(trainer.start.remote())

    # Training must have published at least one new weight revision
    # and must consider itself finished.
    assert ray.get(storage.get_revision.remote()) > 0
    assert ray.get(storage.trained_enough.remote()) is True
def test_bare_model_player_from_checkpoint():
    """A BareModelPlayer restored from a saved checkpoint can pick a move."""
    game = OthelloGame(6)

    some_net = OthelloNNet(game)
    # NOTE(review): hardcoded /tmp/ instead of pytest's tmpdir fixture —
    # checkpoint files leak across runs; consider switching to tmpdir.
    folder, filename = "/tmp/", "checkpoint_bare_model_player"
    some_net.save_checkpoint(folder, filename)
    del some_net

    player = BareModelPlayer(game,
                             OthelloNNet,
                             folder=folder,
                             filename=filename)
    board = game.get_init_board()
    action = player.play(board)
    # Bug fix: a bare `assert action` fails spuriously when the chosen
    # action index is 0 (falsy); only reject a missing action.
    assert action is not None
def test_bare_model_player_from_model():
    """A BareModelPlayer built directly from a network instance can pick a move."""
    game = OthelloGame(6)
    some_net = OthelloNNet(game)
    player = BareModelPlayer(game, some_net)
    board = game.get_init_board()
    action = player.play(board)
    # Bug fix: a bare `assert action` fails spuriously when the chosen
    # action index is 0 (falsy); only reject a missing action.
    assert action is not None
def test_alpha_zero_player_from_model():
    """An AlphaZeroPlayer built directly from a network instance can pick a move."""
    game = OthelloGame(6)
    some_net = OthelloNNet(game)
    player = AlphaZeroPlayer(game, some_net)
    board = game.get_init_board()
    action = player.play(board)
    # Bug fix: a bare `assert action` fails spuriously when the chosen
    # action index is 0 (falsy); only reject a missing action.
    assert action is not None
def test_alpha_zero_player_from_checkpoint():
    """An AlphaZeroPlayer restored from a saved checkpoint can pick a move."""
    game = OthelloGame(6)

    some_net = OthelloNNet(game)
    # NOTE(review): hardcoded /tmp/ instead of pytest's tmpdir fixture —
    # checkpoint files leak across runs; consider switching to tmpdir.
    folder, filename = "/tmp/", "checkpoint_alpha_zero_player"
    some_net.save_checkpoint(folder, filename)
    del some_net

    # A small num_mcts_sims keeps the tree search fast enough for a unit test.
    player = AlphaZeroPlayer(game,
                             OthelloNNet,
                             folder=folder,
                             filename=filename,
                             num_mcts_sims=4)
    board = game.get_init_board()
    action = player.play(board)
    # Bug fix: a bare `assert action` fails spuriously when the chosen
    # action index is 0 (falsy); only reject a missing action.
    assert action is not None
def test_model_trainer_pit_reject_model(capsys, local_ray, tmpdir):
    """With an unreachable win threshold, pitting must reject the new model
    and leave the shared weight revision untouched."""
    game = OthelloGame(6)
    net = OthelloNNet(game)
    storage = SharedStorage.remote(net.get_weights())
    assert ray.get(storage.get_revision.remote()) == 0

    buffer = ReplayBuffer.remote(games_to_play=2, games_to_use=2, folder=tmpdir)
    buffer.add_game_examples.remote(mock_example_data(game))

    # An updateThreshold above 1.0 can never be met, forcing a rejection.
    tweaked_args = dict(args, updateThreshold=1.1)
    trainer = ModelTrainer.options(num_gpus=0).remote(
        buffer, storage, game, net.__class__, tweaked_args,
        pit_against_old_model=True)
    ray.get(trainer.train.remote())

    # The rejected model must not have been published as a new revision.
    assert ray.get(storage.get_revision.remote()) == 0
    captured, _err = capsys.readouterr()
    assert "PITTING AGAINST PREVIOUS VERSION" in captured
    assert "REJECTING NEW MODEL" in captured
def test_self_play(local_ray, tmpdir):
    """One self-play game fills the replay buffer with well-formed examples."""
    game = OthelloGame(6)
    net = OthelloNNet(game)
    storage = SharedStorage.remote(net.get_weights())
    buffer = ReplayBuffer.remote(games_to_play=1, games_to_use=1, folder=tmpdir)
    assert ray.get(buffer.get_number_of_games_played.remote()) == 0

    worker = SelfPlay.remote(buffer, storage, game, net.__class__, dict(args))
    ray.get(worker.start.remote())

    assert ray.get(buffer.get_number_of_games_played.remote()) == 1
    assert ray.get(buffer.played_enough.remote()) is True

    # get_examples itself returns an object reference, hence the nested get.
    games = ray.get(ray.get(buffer.get_examples.remote()))
    assert len(games) == 1
    examples = games[0]
    assert len(examples) > 2

    # Each example is a (board, policy, winner) triple.
    board, policy, winner = examples[0]
    assert isinstance(board, type(game.get_init_board()))
    assert len(policy) == game.get_action_size()
    assert all(0 <= p <= 1 for p in policy)
    assert winner in (1, -1)
def test_coach(capsys, tmpdir):
    # Smoke test: run a full (non-ray) Coach learning loop to completion.
    # NOTE(review): this mutates the module-level `args` object in place,
    # which can leak the tmpdir checkpoint path into later tests — consider
    # passing a copy of `args` instead; verify the copy semantics of its type.
    args.checkpoint = tmpdir
    game = OthelloGame(6)
    nnet = OthelloNNet(game)
    coach = Coach(game, nnet, args)
    coach.learn()