コード例 #1
0
def init_botvbot(bot1idx,
                 bot2idx,
                 num_games,
                 args,
                 ai1_option,
                 ai2_option,
                 game_option,
                 *,
                 act_name="act",
                 viz=False):
    """Create a context holding ``num_games`` games between two MediumAI bots.

    Args:
        bot1idx: Index into the local unit-type table; selects the unit type
            handed to the first bot.
        bot2idx: Same as ``bot1idx`` but for the second bot.
        num_games: Number of games to create and push into the context.
        args: Not used by this function; kept for call compatibility.
        ai1_option: Option object forwarded to the first ``MediumAI``.
        ai2_option: Option object forwarded to the second ``MediumAI``.
        game_option: Template game option. A fresh ``RTSGameOption`` is built
            from it for every game, with ``seed`` and ``game_id`` overridden.
            A seed value of 777 requests a random base seed per game.
        act_name: Prefix of the two data-channel names (``<name>1`` and
            ``<name>2``).
        viz: When true, one randomly chosen game gets the default spectator.

    Returns:
        Tuple ``(context, act1_dc, act2_dc)``. The two data channels are
        created here but are not attached to the scripted bots.
    """
    batchsize = min(32, max(num_games // 2, 1))

    act1_dc = tube.DataChannel(act_name + "1", batchsize, 1)
    act2_dc = tube.DataChannel(act_name + "2", batchsize, 1)
    context = tube.Context()
    idx2utype = [
        minirts.UnitType.SPEARMAN,
        minirts.UnitType.SWORDMAN,
        minirts.UnitType.CAVALRY,
        minirts.UnitType.DRAGON,
        minirts.UnitType.ARCHER,
    ]

    if game_option.seed == 777:
        print("Using random seeds...")

    # Index of the game that receives a spectator when ``viz`` is set.
    # The previous ``randint(1, num_games - 1)`` raised ValueError for a
    # single game and could never select game 0; drawing from the full index
    # range matches how create_drift_games picks its spectated game.
    spectated_game = random.randrange(num_games) if num_games > 0 else None

    for i in range(num_games):
        if game_option.seed == 777:
            seed = random.randint(1, 123456)
        else:
            seed = game_option.seed

        g_option = minirts.RTSGameOption(game_option)
        # Offset by the game index so games sharing a base seed still differ.
        g_option.seed = seed + i
        g_option.game_id = str(i)
        g = minirts.RTSGame(g_option)

        bot1 = minirts.MediumAI(ai1_option, 0, None, idx2utype[bot1idx], 1)
        bot2 = minirts.MediumAI(ai2_option, 0, None, idx2utype[bot2idx],
                                1)  # Utype + tower

        g.add_bot(bot1)
        g.add_bot(bot2)

        if viz and i == spectated_game:
            g.add_default_spectator()

        context.push_env_thread(g)

    return context, act1_dc, act2_dc
コード例 #2
0
def create_drift_games(num_sp,
                       num_rb,
                       args,
                       ai1_option,
                       ai2_option,
                       game_option,
                       *,
                       act_name="act",
                       viz=False):
    """Create ``num_rb`` games against a MediumAI with a random unit type.

    Args:
        num_sp: Only contributes to the batch-size computation in this
            function; no games are created for it here.
        num_rb: Number of games to create and push into the context.
        args: Not used by this function; kept for call compatibility.
        ai1_option: Forwarded to ``create_game`` to build the first bot.
        ai2_option: Option object forwarded to the ``MediumAI`` opponent.
        game_option: Forwarded to ``create_game``. A seed value of 777
            requests a random seed per game.
        act_name: Prefix of the two data-channel names (``<name>1`` and
            ``<name>2``).
        viz: When true, one randomly chosen game gets the default spectator.

    Returns:
        Tuple ``(context, act1_dc, act2_dc)``.
    """
    if game_option.seed == 777:
        print("Using random seeds...")

    total_games = num_sp + num_rb
    batchsize = min(32, max(total_games // 2, 1))

    act1_dc = tube.DataChannel(act_name + "1", batchsize, 1)
    act2_dc = tube.DataChannel(act_name + "2", batchsize, 1)
    context = tube.Context()
    idx2utype = [
        minirts.UnitType.SPEARMAN,
        minirts.UnitType.SWORDMAN,
        minirts.UnitType.CAVALRY,
        minirts.UnitType.DRAGON,
        minirts.UnitType.ARCHER,
    ]

    game_id = 0
    # Index of the game that receives a spectator when ``viz`` is set.
    # Guarded because randint(0, -1) raises ValueError when num_rb is 0.
    rnd_num = random.randint(0, num_rb - 1) if num_rb > 0 else None
    for i in range(num_rb):
        if game_option.seed == 777:
            seed = random.randint(1, 123456)
        else:
            seed = game_option.seed

        # create_game is defined elsewhere; it returns the first bot and the
        # game object it belongs to.
        bot1, g = create_game(act1_dc, ai1_option, game_option, game_id, seed)

        # The opponent gets a uniformly drawn unit type.
        utype = idx2utype[random.randint(0, len(idx2utype) - 1)]
        bot2 = minirts.MediumAI(ai2_option, 0, None, utype, 1)  # Utype + tower

        g.add_bot(bot1)
        g.add_bot(bot2)

        if viz and i == rnd_num:
            g.add_default_spectator()

        context.push_env_thread(g)
        game_id += 1

    return context, act1_dc, act2_dc
コード例 #3
0
def init_games(num_games,
               ai1_option,
               ai2_option,
               game_option,
               *,
               act_name="act"):
    """Create ``num_games`` games between two CheatExecutorAI bots.

    Args:
        num_games: Number of games to create and push into the context.
        ai1_option: Option object forwarded to the first bot.
        ai2_option: Option object forwarded to the second bot.
        game_option: Template game option. A fresh ``RTSGameOption`` is built
            from it for every game, with ``seed``, ``game_id`` and (when set)
            ``save_replay_prefix`` overridden. A seed value of 777 requests a
            random base seed per game.
        act_name: Prefix of the two data-channel names (``<name>1`` and
            ``<name>2``).

    Returns:
        Tuple ``(context, act1_dc, act2_dc)``; the first channel is handed to
        every first bot and the second channel to every second bot.
    """
    batchsize = min(32, max(num_games // 2, 1))
    act1_dc = tube.DataChannel(act_name + "1", batchsize, 1)
    act2_dc = tube.DataChannel(act_name + "2", batchsize, 1)
    context = tube.Context()

    # Announce random seeding once up front rather than once per game.
    if game_option.seed == 777:
        print("Using random seeds...")

    for i in range(num_games):
        g_option = minirts.RTSGameOption(game_option)
        if game_option.seed == 777:
            seed = random.randint(1, 123456)
        else:
            seed = game_option.seed

        # Offset by the game index so games sharing a base seed still differ.
        g_option.seed = seed + i
        g_option.game_id = str(i)
        if game_option.save_replay_prefix:
            g_option.save_replay_prefix = (game_option.save_replay_prefix +
                                           "_0_" + str(i))

        g = minirts.RTSGame(g_option)
        bot1 = minirts.CheatExecutorAI(ai1_option, 0, None, act1_dc)
        bot2 = minirts.CheatExecutorAI(ai2_option, 0, None, act2_dc)

        g.add_bot(bot1)
        g.add_bot(bot2)
        context.push_env_thread(g)

    return context, act1_dc, act2_dc
コード例 #4
0
def create_tp_environment(
    seed_generator: Iterator[int],
    game_params: GameParams,
    simulation_params: SimulationParams,
    execution_params: ExecutionParams,
    pure_mcts: bool,
) -> Tuple[tube.Context, Optional[tube.DataChannel], Callable[[], int]]:
    """Set up a single evaluation game that includes a ``TPPlayer``.

    Args:
        seed_generator: Source of seeds; one is consumed for the game and the
            generator is also forwarded to ``create_player``.
        game_params: Forwarded to ``create_game``.
        simulation_params: Supplies ``num_actor`` and ``num_rollouts``.
        execution_params: Supplies ``human_first``, ``time_ratio`` and
            ``total_time``.
        pure_mcts: When true no actor data channel is created.

    Returns:
        Tuple of the context, the actor channel (``None`` for pure MCTS) and
        a zero-argument callable that reads the TP player's entry from the
        game result.
    """
    human_first = execution_params.human_first
    time_ratio = execution_params.time_ratio
    total_time = execution_params.total_time
    context = tube.Context()

    if pure_mcts:
        actor_channel = None
    else:
        actor_channel = tube.DataChannel("act", simulation_params.num_actor,
                                         1)

    game = create_game(game_params,
                       num_episode=1,
                       seed=next(seed_generator),
                       eval_mode=True,
                       per_thread_batchsize=0)
    engine_player = create_player(seed_generator=seed_generator,
                                  game=game,
                                  num_actor=simulation_params.num_actor,
                                  num_rollouts=simulation_params.num_rollouts,
                                  pure_mcts=pure_mcts,
                                  actor_channel=actor_channel,
                                  assembler=None,
                                  human_mode=True,
                                  total_time=total_time,
                                  time_ratio=time_ratio)
    tp_player = polygames.TPPlayer()

    # Registration order follows ``human_first``; a one-player game only gets
    # the TP player.
    if game.is_one_player_game():
        game.add_tp_player(tp_player)
    elif human_first:
        game.add_tp_player(tp_player)
        game.add_eval_player(engine_player)
    else:
        game.add_eval_player(engine_player)
        game.add_tp_player(tp_player)

    context.push_env_thread(game)

    def get_result_for_tp_player():
        # ``not human_first`` is a bool used as index: 0 when the TP player
        # was registered first, 1 otherwise.
        return game.get_result()[not human_first]

    return context, actor_channel, get_result_for_tp_player
コード例 #5
0
ファイル: human_coach.py プロジェクト: tzs930/minirts
def create_game(ai1_option, ai2_option, game_option, *, act_name='act'):
    """Build one spectated game: CheatExecutorAI versus MediumAI.

    Prints the ``info()`` of the three option objects, creates a data channel
    named ``act_name`` for the first bot, and pushes the game into a new
    context.

    Returns:
        Tuple ``(context, act_dc)``.
    """
    labelled_options = (
        ('ai1 option:', ai1_option),
        ('ai2 option:', ai2_option),
        ('game option:', game_option),
    )
    for label, option in labelled_options:
        print(label)
        print(option.info())

    act_dc = tube.DataChannel(act_name, 1, -1)
    context = tube.Context()

    game = minirts.RTSGame(game_option)
    executor_bot = minirts.CheatExecutorAI(ai1_option, 0, None, act_dc)
    scripted_bot = minirts.MediumAI(
        ai2_option, 0, None, minirts.UnitType.INVALID_UNITTYPE, False)
    game.add_bot(executor_bot)
    game.add_bot(scripted_bot)
    game.add_default_spectator()

    context.push_env_thread(game)
    return context, act_dc
コード例 #6
0
def create_game(num_games,
                ai1_option,
                ai2_option,
                game_option,
                *,
                act_name='act'):
    """Build ``num_games`` games between two CheatExecutorAI bots.

    Prints the ``info()`` of the three option objects, then creates one game
    per index. Each game gets its own ``RTSGameOption`` built from
    ``game_option`` with the seed offset by the game index and, when a replay
    prefix is configured, the index appended to that prefix.

    Returns:
        Tuple ``(context, act1_dc, act2_dc)``.
    """
    for label, option in (('ai1 option:', ai1_option),
                          ('ai2 option:', ai2_option),
                          ('game option:', game_option)):
        print(label)
        print(option.info())

    channel_batch = min(32, max(num_games // 2, 1))
    act1_dc = tube.DataChannel(act_name + '1', channel_batch, 1)
    act2_dc = tube.DataChannel(act_name + '2', channel_batch, 1)
    context = tube.Context()
    # Not referenced below; it belongs to a disabled variant in which the
    # second bot was a MediumAI cycling through these unit types.
    idx2utype = [
        minirts.UnitType.SPEARMAN,
        minirts.UnitType.SWORDMAN,
        minirts.UnitType.CAVALRY,
        minirts.UnitType.DRAGON,
        minirts.UnitType.ARCHER,
    ]

    for game_index in range(num_games):
        per_game_option = minirts.RTSGameOption(game_option)
        per_game_option.seed = game_option.seed + game_index
        if game_option.save_replay_prefix:
            suffix = str(game_index)
            per_game_option.save_replay_prefix = (
                game_option.save_replay_prefix + suffix)

        game = minirts.RTSGame(per_game_option)
        first_bot = minirts.CheatExecutorAI(ai1_option, 0, None, act1_dc)
        second_bot = minirts.CheatExecutorAI(ai2_option, 0, None, act2_dc)
        game.add_bot(first_bot)
        game.add_bot(second_bot)
        context.push_env_thread(game)

    return context, act1_dc, act2_dc
コード例 #7
0
# Copyright (c) Facebook, Inc. and its affiliates.
コード例 #8
0
def create_training_environment(
    seed_generator: Iterator[int], model_path: Path, device: str,
    game_params: GameParams, simulation_params: SimulationParams,
    execution_params: ExecutionParams, model
) -> Tuple[tube.Context, polygames.ModelManager, Callable[[], List[int]],
           bool]:
    """Build the self-play training context and its model manager.

    Args:
        seed_generator: Source of seeds for the model managers, the games and
            the players; also used to decide which player is added first.
        model_path: Path of the model handed to the main ``ModelManager``.
        device: Device string handed to the model managers.
        game_params: Parameters forwarded to ``create_game`` and
            ``create_player``.
        simulation_params: Batch sizes, replay capacity, channel settings and
            the number of games/actors/rollouts.
        execution_params: Supplies ``listen``/``connect`` endpoints, the
            optional opponent model path, ``checkpoint_dir``,
            ``tournament_mode`` and ``rnn_seqlen``.
        model: Model object; only ``rnn_state_shape`` and ``logit_value`` are
            read from it (with defaults when absent).

    Returns:
        Tuple of the context, the main model manager, a zero-argument callable
        returning the first result entry of every created game, and the
        ``is_client`` flag.

    Raises:
        RuntimeError: If both a listen and a connect endpoint are given.
    """
    games = []
    context = tube.Context()
    print("Game generation device: {}".format(device))
    listen_ep = execution_params.listen
    connect_ep = execution_params.connect
    opponent_model_path = execution_params.opponent_model_path
    is_server = listen_ep != ""
    is_client = connect_ep != ""
    print("is_server is ", is_server)
    print("is_client is ", is_client)
    # Reject the contradictory configuration before any server or client is
    # started, instead of after both have already been launched.
    if is_client and is_server:
        raise RuntimeError(
            "Client and server parameters have both been specified")
    model_manager = polygames.ModelManager(
        simulation_params.act_batchsize,
        str(device),
        simulation_params.replay_capacity,
        next(seed_generator),
        str(model_path),
        simulation_params.train_channel_timeout_ms,
        simulation_params.train_channel_num_slots,
    )
    model_manager.set_find_batch_size_max_bs(simulation_params.bsfinder_max_bs)
    model_manager.set_find_batch_size_max_ms(simulation_params.bsfinder_max_ms)
    if is_server:
        model_manager.start_server(listen_ep)
    if is_client:
        model_manager.start_client(connect_ep)

    rnn_state_shape = getattr(model, "rnn_state_shape", [])
    logit_value = getattr(model, "logit_value", False)

    print("rnn_state_shape is ", rnn_state_shape)

    if simulation_params.num_threads != 0:
        polygames.init_threads(simulation_params.num_threads)

    # Opponent-specific settings stay None unless an opponent checkpoint is
    # loaded; None means "fall back to the main model's value" further down.
    opgame = None
    op_rnn_state_shape = None
    op_rnn_seqlen = None
    op_logit_value = None
    if not is_server:
        if opponent_model_path:
            print("loading opponent model")
            checkpoint = utils.load_checkpoint(
                checkpoint_path=opponent_model_path)
            opmodel = create_model(
                game_params=checkpoint["game_params"],
                model_params=checkpoint["model_params"],
                resume_training=True,
                model_state_dict=checkpoint["model_state_dict"],
            )
            # Re-save the opponent under the checkpoint directory and point
            # the opponent model manager at that copy.
            opponent_model_path = execution_params.checkpoint_dir / "model_opponent.pt"
            opmodel.save(str(opponent_model_path))
            opgame = create_game(
                checkpoint["game_params"],
                num_episode=-1,
                seed=next(seed_generator),
                eval_mode=False,
            )
            op_rnn_state_shape = getattr(opmodel, "rnn_state_shape", [])
            op_rnn_seqlen = 0
            if hasattr(checkpoint["execution_params"], "rnn_seqlen"):
                op_rnn_seqlen = checkpoint["execution_params"].rnn_seqlen
            op_logit_value = getattr(opmodel, "logit_value", False)
        model_manager_opponent = polygames.ModelManager(
            simulation_params.act_batchsize,
            str(device),
            simulation_params.replay_capacity,
            next(seed_generator),
            str(opponent_model_path)
            if opponent_model_path else str(model_path),
            simulation_params.train_channel_timeout_ms,
            simulation_params.train_channel_num_slots,
        )
        model_manager_opponent.set_find_batch_size_max_bs(
            simulation_params.bsfinder_max_bs)
        model_manager_opponent.set_find_batch_size_max_ms(
            simulation_params.bsfinder_max_ms)
        print("tournament_mode is " + str(execution_params.tournament_mode))
        if execution_params.tournament_mode:
            model_manager_opponent.set_is_tournament_opponent(True)
        if opponent_model_path:
            model_manager_opponent.set_dont_request_model_updates(True)
        if is_client:
            model_manager_opponent.start_client(connect_ep)
    if not is_server:
        train_channel = model_manager.get_train_channel()
        actor_channel = model_manager.get_act_channel()

        op_actor_channel = actor_channel
        if model_manager_opponent is not None:
            op_actor_channel = model_manager_opponent.get_act_channel()

        for _ in range(simulation_params.num_game):
            game = create_game(
                game_params,
                num_episode=-1,
                seed=next(seed_generator),
                eval_mode=False,
                per_thread_batchsize=simulation_params.per_thread_batchsize,
                rewind=simulation_params.rewind,
                predict_end_state=game_params.predict_end_state,
                predict_n_states=game_params.predict_n_states,
            )
            player_1 = create_player(
                seed_generator=seed_generator,
                game=game,
                player=game_params.player,
                num_actor=simulation_params.num_actor,
                num_rollouts=simulation_params.num_rollouts,
                pure_mcts=False,
                actor_channel=actor_channel,
                model_manager=model_manager,
                human_mode=False,
                sample_before_step_idx=simulation_params.
                sample_before_step_idx,
                randomized_rollouts=simulation_params.randomized_rollouts,
                sampling_mcts=simulation_params.sampling_mcts,
                rnn_state_shape=rnn_state_shape,
                rnn_seqlen=execution_params.rnn_seqlen,
                logit_value=logit_value)
            player_1.set_name("dev")
            if game.is_one_player_game():
                game.add_player(player_1, train_channel)
            else:
                player_2 = create_player(
                    seed_generator=seed_generator,
                    game=opgame if opgame is not None else game,
                    player=game_params.player,
                    num_actor=simulation_params.num_actor,
                    num_rollouts=simulation_params.num_rollouts,
                    pure_mcts=False,
                    actor_channel=op_actor_channel,
                    model_manager=model_manager_opponent,
                    human_mode=False,
                    sample_before_step_idx=simulation_params.
                    sample_before_step_idx,
                    randomized_rollouts=simulation_params.randomized_rollouts,
                    sampling_mcts=simulation_params.sampling_mcts,
                    rnn_state_shape=op_rnn_state_shape
                    if op_rnn_state_shape is not None else rnn_state_shape,
                    rnn_seqlen=op_rnn_seqlen if op_rnn_seqlen is not None else
                    execution_params.rnn_seqlen,
                    logit_value=op_logit_value
                    if op_logit_value is not None else logit_value)
                player_2.set_name("opponent")
                # The parity of the next seed decides which player is added
                # first.
                if next(seed_generator) % 2 == 0:
                    game.add_player(player_1, train_channel, game, player_1)
                    game.add_player(player_2, train_channel,
                                    opgame if opgame is not None else game,
                                    player_1)
                else:
                    game.add_player(player_2, train_channel,
                                    opgame if opgame is not None else game,
                                    player_1)
                    game.add_player(player_1, train_channel, game, player_1)

            context.push_env_thread(game)
            games.append(game)

    def get_train_reward() -> List[int]:
        """Collect the first result entry of every game (and of opgame)."""
        reward = []
        for game in games:
            reward.append(game.get_result()[0])
        if opgame is not None:
            reward.append(opgame.get_result()[0])

        return reward

    return context, model_manager, get_train_reward, is_client
コード例 #9
0
def create_training_environment(
    seed_generator: Iterator[int], model_path: Path,
    game_generation_devices: List[str], game_params: GameParams,
    simulation_params: SimulationParams, execution_params: ExecutionParams
) -> Tuple[tube.Context, tube.ChannelAssembler, Callable[[], List[int]], bool]:
    """Build the self-play training context and its channel assembler.

    Args:
        seed_generator: Source of seeds for the assemblers, games and players.
        model_path: Path of the model handed to the main assembler.
        game_generation_devices: Device strings handed to the assemblers.
        game_params: Parameters forwarded to ``create_game``.
        simulation_params: Batch sizes, replay capacity, channel settings and
            the number of games/actors/rollouts.
        execution_params: Supplies the server listen endpoint, the server
            connect hostname, the optional opponent model path and
            ``checkpoint_dir``.

    Returns:
        Tuple of the context, the main assembler, a zero-argument callable
        returning the first result entry of every created game, and the
        ``is_client`` flag.

    Raises:
        RuntimeError: If both a listen endpoint and a connect hostname are
            given.
    """
    games = []
    context = tube.Context()
    print("Game generation devices: {}".format(game_generation_devices))
    server_listen_endpoint = execution_params.server_listen_endpoint
    server_connect_hostname = execution_params.server_connect_hostname
    opponent_model_path = execution_params.opponent_model_path
    is_server = server_listen_endpoint != ""
    is_client = server_connect_hostname != ""
    print("is_server is ", is_server)
    print("is_client is ", is_client)
    # Reject the contradictory configuration before any server or client is
    # started, instead of after both have already been launched.
    if is_client and is_server:
        raise RuntimeError(
            "Client and server parameters have both been specified")
    assembler = tube.ChannelAssembler(
        simulation_params.act_batchsize,
        len(game_generation_devices) if not is_server else 0,
        game_generation_devices,
        simulation_params.replay_capacity,
        next(seed_generator),
        str(model_path),
        simulation_params.train_channel_timeout_ms,
        simulation_params.train_channel_num_slots,
    )
    if is_server:
        assembler.start_server(server_listen_endpoint)
    if is_client:
        assembler.start_client(server_connect_hostname)
    if not is_server:
        if opponent_model_path:
            print("loading opponent model")
            checkpoint = utils.load_checkpoint(
                checkpoint_path=opponent_model_path)
            model = create_model(
                game_params=checkpoint["game_params"],
                model_params=checkpoint["model_params"],
                resume_training=True,
                model_state_dict=checkpoint["model_state_dict"],
            )
            # Re-save the opponent under the checkpoint directory and point
            # the opponent assembler at that copy.
            opponent_model_path = execution_params.checkpoint_dir / "model_opponent.pt"
            model.save(str(opponent_model_path))
        assembler_opponent = tube.ChannelAssembler(
            simulation_params.act_batchsize,
            len(game_generation_devices) if not is_server else 0,
            game_generation_devices,
            simulation_params.replay_capacity,
            next(seed_generator),
            str(opponent_model_path)
            if opponent_model_path else str(model_path),
            simulation_params.train_channel_timeout_ms,
            simulation_params.train_channel_num_slots,
        )
        assembler_opponent.set_is_tournament_opponent(True)
        if opponent_model_path:
            assembler_opponent.set_dont_request_model_updates(True)
        if is_client:
            assembler_opponent.start_client(server_connect_hostname)
    if not is_server:
        train_channel = assembler.get_train_channel()
        actor_channels = assembler.get_act_channels()
        actor_channel = actor_channels[0]

        for i in range(simulation_params.num_game):
            game = create_game(
                game_params,
                num_episode=-1,
                seed=next(seed_generator),
                eval_mode=False,
                per_thread_batchsize=simulation_params.per_thread_batchsize,
            )
            if simulation_params.per_thread_batchsize > 0:
                # Batched mode: players are bound to an assembler and share
                # the first actor channel.
                player_1 = create_player(
                    seed_generator=seed_generator,
                    game=game,
                    num_actor=simulation_params.num_actor,
                    num_rollouts=simulation_params.num_rollouts,
                    pure_mcts=False,
                    actor_channel=actor_channel,
                    assembler=assembler,
                    human_mode=False,
                )
                player_1.set_name("dev")
                if game.is_one_player_game():
                    game.add_player(player_1, train_channel)
                else:
                    player_2 = create_player(
                        seed_generator=seed_generator,
                        game=game,
                        num_actor=simulation_params.num_actor,
                        num_rollouts=simulation_params.num_rollouts,
                        pure_mcts=False,
                        actor_channel=actor_channel,
                        assembler=assembler_opponent,
                        human_mode=False,
                    )
                    player_2.set_name("opponent")
                    # Alternate which player is added first from one game to
                    # the next.
                    if i % 2 == 0:
                        game.add_player(player_1, train_channel)
                        game.add_player(player_2, train_channel)
                    else:
                        game.add_player(player_2, train_channel)
                        game.add_player(player_1, train_channel)
            else:
                # Non-batched mode: no assembler on the players; actor
                # channels are assigned round-robin by game index.
                player_1 = create_player(
                    seed_generator=seed_generator,
                    game=game,
                    num_actor=simulation_params.num_actor,
                    num_rollouts=simulation_params.num_rollouts,
                    pure_mcts=False,
                    actor_channel=actor_channels[i % len(actor_channels)],
                    assembler=None,
                    human_mode=False,
                )
                game.add_player(player_1, train_channel)
                if not game.is_one_player_game():
                    player_2 = create_player(
                        seed_generator=seed_generator,
                        game=game,
                        num_actor=simulation_params.num_actor,
                        num_rollouts=simulation_params.num_rollouts,
                        pure_mcts=False,
                        actor_channel=actor_channels[i % len(actor_channels)],
                        assembler=None,
                        human_mode=False,
                    )
                    game.add_player(player_2, train_channel)

            context.push_env_thread(game)
            games.append(game)

    def get_train_reward() -> List[int]:
        """Collect the first result entry of every created game."""
        reward = []
        for game in games:
            reward.append(game.get_result()[0])

        return reward

    return context, assembler, get_train_reward, is_client
コード例 #10
0
def create_evaluation_environment(
    seed_generator: Iterator[int],
    game_params: GameParams,
    eval_params: EvalParams,
    current_batch_size: int = None,
    pure_mcts_eval: bool = False,
    pure_mcts_opponent: bool = True,
    num_evaluated_games: int = 0
) -> Tuple[
    tube.Context,
    Optional[tube.DataChannel],
    Optional[tube.DataChannel],
    Callable[[], List[int]],
]:
    """Build a context of evaluation games and a reward accessor.

    Args:
        seed_generator: Source of seeds for games and players.
        game_params: Forwarded to ``create_game``.
        eval_params: Supplies the game count and the actor/rollout counts for
            the evaluated player and its opponent.
        current_batch_size: Number of games to create; a falsy value means
            ``eval_params.num_game_eval``.
        pure_mcts_eval: When true the evaluated player gets no actor channel.
        pure_mcts_opponent: When true the opponent gets no actor channel.
        num_evaluated_games: Offset added to the loop index before asking
            ``player_moves_first`` who is added first.

    Returns:
        Tuple of the context, the evaluated player's actor channel, the
        opponent's actor channel and a zero-argument callable returning the
        evaluated player's result entry for every game.
    """
    num_game = eval_params.num_game_eval
    num_actor_eval = eval_params.num_actor_eval
    num_rollouts_eval = eval_params.num_rollouts_eval
    num_actor_opponent = eval_params.num_actor_opponent
    num_rollouts_opponent = eval_params.num_rollouts_opponent
    # Games are split by whether the evaluated player was added first or
    # second, because its result sits at a different index in each case.
    first_hand = []
    second_hand = []
    games = []

    context = tube.Context()
    if pure_mcts_eval:
        actor_channel_eval = None
    else:
        actor_channel_eval = tube.DataChannel(
            "act_eval", num_game * num_actor_eval, 1)
    if pure_mcts_opponent:
        actor_channel_opponent = None
    else:
        actor_channel_opponent = tube.DataChannel(
            "act_opponent", num_game * num_actor_opponent, 1)

    # Keyword arguments shared by both players.
    shared_player_kwargs = dict(
        player="mcts",
        model_manager=None,
        human_mode=False,
        sample_before_step_idx=8,
        randomized_rollouts=False,
        sampling_mcts=False,
    )

    games_to_create = current_batch_size or num_game
    for game_no in range(games_to_create):
        game = create_game(
            game_params, num_episode=1, seed=next(seed_generator),
            eval_mode=True)
        player = create_player(
            seed_generator=seed_generator,
            game=game,
            num_actor=num_actor_eval,
            num_rollouts=num_rollouts_eval,
            pure_mcts=pure_mcts_eval,
            actor_channel=actor_channel_eval,
            **shared_player_kwargs,
        )
        if game.is_one_player_game():
            game.add_eval_player(player)
            first_hand.append(game)
        else:
            opponent = create_player(
                seed_generator=seed_generator,
                game=game,
                num_actor=num_actor_opponent,
                num_rollouts=num_rollouts_opponent,
                pure_mcts=pure_mcts_opponent,
                actor_channel=actor_channel_opponent,
                **shared_player_kwargs,
            )
            game_id = num_evaluated_games + game_no
            if player_moves_first(game_id, num_game):
                seating, bucket = (player, opponent), first_hand
            else:
                seating, bucket = (opponent, player), second_hand
            for participant in seating:
                game.add_eval_player(participant)
            bucket.append(game)

        context.push_env_thread(game)
        games.append(game)

    def get_eval_reward():
        # Index 0 when the evaluated player was added first, 1 otherwise.
        reward = [hand.get_result()[0] for hand in first_hand]
        reward.extend(hand.get_result()[1] for hand in second_hand)
        return reward

    return context, actor_channel_eval, actor_channel_opponent, get_eval_reward
コード例 #11
0
ファイル: human.py プロジェクト: facebookincubator/Polygames
def create_human_environment(
    seed_generator: Iterator[int], game_params: GameParams,
    simulation_params: SimulationParams, execution_params: ExecutionParams,
    pure_mcts: bool, model
) -> Tuple[tube.Context, Optional[tube.DataChannel], Callable[[], int]]:
    """Set up a single game between a ``HumanPlayer`` and an MCTS player.

    Args:
        seed_generator: Source of seeds; one is consumed for the game and the
            generator is also forwarded to ``create_player``.
        game_params: Forwarded to ``create_game``.
        simulation_params: Supplies ``num_actor``, ``num_rollouts`` and
            ``rewind``.
        execution_params: Supplies ``human_first``, ``time_ratio``,
            ``total_time`` and ``rnn_seqlen``.
        pure_mcts: When true no actor data channel is created.
        model: Optional model; ``rnn_cells``, ``rnn_channels`` and
            ``logit_value`` are read from it when present.

    Returns:
        Tuple of the context, the actor channel (``None`` for pure MCTS) and
        a zero-argument callable that reads the human player's entry from the
        game result.
    """
    human_first = execution_params.human_first
    time_ratio = execution_params.time_ratio
    total_time = execution_params.total_time
    context = tube.Context()

    if pure_mcts:
        actor_channel = None
    else:
        actor_channel = tube.DataChannel("act", simulation_params.num_actor,
                                         1)

    # The recurrent state shape is only populated for models that report a
    # positive number of RNN cells.
    model_is_recurrent = (model is not None and hasattr(model, "rnn_cells")
                          and model.rnn_cells > 0)
    rnn_state_shape = ([model.rnn_cells, model.rnn_channels]
                       if model_is_recurrent else [])
    # NOTE(review): rnn_state_size is computed but not referenced later in
    # this function.
    rnn_state_size = (rnn_state_shape[0] * rnn_state_shape[1]
                      if len(rnn_state_shape) >= 2 else 0)
    logit_value = getattr(model, "logit_value", False)

    game = create_game(game_params,
                       num_episode=1,
                       seed=next(seed_generator),
                       eval_mode=True,
                       per_thread_batchsize=0,
                       rewind=simulation_params.rewind,
                       predict_end_state=game_params.predict_end_state,
                       predict_n_states=game_params.predict_n_states)
    engine_player = create_player(
        seed_generator=seed_generator,
        game=game,
        player="mcts",
        num_actor=simulation_params.num_actor,
        num_rollouts=simulation_params.num_rollouts,
        pure_mcts=pure_mcts,
        actor_channel=actor_channel,
        model_manager=None,
        human_mode=True,
        total_time=total_time,
        time_ratio=time_ratio,
        sample_before_step_idx=80,
        randomized_rollouts=False,
        sampling_mcts=False,
        rnn_state_shape=rnn_state_shape,
        rnn_seqlen=execution_params.rnn_seqlen,
        logit_value=logit_value)
    human_player = polygames.HumanPlayer()

    # Registration order follows ``human_first``; a one-player game only gets
    # the human player.
    if game.is_one_player_game():
        game.add_human_player(human_player)
    elif human_first:
        game.add_human_player(human_player)
        game.add_eval_player(engine_player)
    else:
        game.add_eval_player(engine_player)
        game.add_human_player(human_player)

    context.push_env_thread(game)

    def get_result_for_human_player():
        # ``not human_first`` is a bool used as index: 0 when the human was
        # registered first, 1 otherwise.
        return game.get_result()[not human_first]

    return context, actor_channel, get_result_for_human_player
コード例 #12
0
ファイル: simple_mask_search.py プロジェクト: apjacob/minirts
# Copyright (c) Facebook, Inc. and its affiliates.
コード例 #13
0
ファイル: sa_mask_search.py プロジェクト: apjacob/minirts
# Copyright (c) Facebook, Inc. and its affiliates.
コード例 #14
0
# Copyright (c) Facebook, Inc. and its affiliates.