Beispiel #1
0
    async def play(self, config, game_id):
        """Play a single game and hand the collected experience to the players.

        Resets the Dota service with `config`, creates a `Player` for every
        bot-controlled hero pick, and then alternates observe/act for both
        teams until `self.max_dota_time` is reached or an observation comes
        back with a non-OK status.

        Args:
            config: Game configuration passed to `dota_service.reset`. Its
                `hero_picks` are matched positionally against the players
                listed in the reset response.
            game_id: Identifier forwarded to every `Player` of this game.

        Returns:
            None. Returns early, without finishing the players, when the game
            ended with status RESOURCE_EXHAUSTED, FAILED_PRECONDITION or
            OUT_OF_RANGE.

        Raises:
            AssertionError: If a hero pick and the player at the same position
                in the reset response are on different teams.
            asyncio.TimeoutError: If the reset does not complete within 120
                seconds.
        """
        logger.info('Starting game.')

        # Use the latest weights by default.
        use_latest_weights = {TEAM_RADIANT: True, TEAM_DIRE: True}
        if random.random() > self.latest_weights_prob:
            # Randomly pick the team that will use the old weights.
            old_model_team = random.choice([TEAM_RADIANT, TEAM_DIRE])
            use_latest_weights[old_model_team] = False

        drawing = Drawing(
        )  # TODO(tzaman): drawing should include what's visible to the player

        # Reset and obtain the initial observation. This dictates who we are controlling;
        # it is done before the player definition, because there might be humans playing
        # that take up bot positions.
        response = await asyncio.wait_for(self.dota_service.reset(config),
                                          timeout=120)

        player_request = config.hero_picks
        players_response = response.players  # Lists all human and bot players.
        players = {TEAM_RADIANT: [], TEAM_DIRE: []}
        for p_req, p_res in zip(player_request, players_response):
            # The picks and the response are paired by position, so the request
            # must be compared against the response (not against itself).
            assert p_req.team_id == p_res.team_id, (
                'Requested hero pick and reported player are on different teams.'
            )  # TODO(tzaman): more tests?
            if p_res.is_bot and p_req.control_mode == HERO_CONTROL_MODE_CONTROLLED:
                player = Player(
                    game_id=game_id,
                    player_id=p_res.id,
                    team_id=p_res.team_id,
                    hero=p_res.hero,
                    experience_channel=self.experience_channel,
                    use_latest_weights=use_latest_weights[p_res.team_id],
                    drawing=drawing,
                    validation=self.validation,
                )
                players[p_res.team_id].append(player)

        # The world states from the reset serve as the "previous" observation
        # for the first reward computation of each team.
        prev_obs = {
            TEAM_RADIANT: response.world_state_radiant,
            TEAM_DIRE: response.world_state_dire,
        }
        done = False
        # Start below any possible game time so the loop is entered at least once.
        dota_time = -float('Inf')
        end_state = None
        while dota_time < self.max_dota_time:
            reward_sum_step = {TEAM_RADIANT: 0, TEAM_DIRE: 0}
            for team_id in [TEAM_RADIANT, TEAM_DIRE]:
                logger.debug('\ndota_time={:.2f}, team={}'.format(
                    dota_time, team_id))

                response = await self.dota_service.observe(
                    ObserveConfig(team_id=team_id))
                if response.status != Status.Value('OK'):
                    # Any non-OK status ends the game; remember it for the
                    # end-state handling after the loop.
                    end_state = response.status
                    done = True
                    break
                obs = response.world_state
                dota_time = obs.dota_time

                # We now loop over each player in this team and get each player's action.
                actions = []
                for player in players[team_id]:
                    player.compute_reward(prev_obs=prev_obs[team_id], obs=obs)
                    reward_sum_step[team_id] += sum(
                        player.rewards[-1].values())
                    # Action selection is inference only; no gradients needed.
                    with torch.no_grad():
                        actions_player = player.obs_to_actions(obs=obs)
                    actions.extend(actions_player)

                actions_pb = CMsgBotWorldState.Actions(actions=actions)
                actions_pb.dota_time = obs.dota_time

                _ = await self.dota_service.act(
                    Actions(actions=actions_pb, team_id=team_id))

                prev_obs[team_id] = obs

            # NOTE(review): this block also runs on the step in which an
            # observation failed (`done` is only checked below), so the latest
            # rewards are still adjusted with this step's partial sums —
            # confirm that this is intended.
            if not self.validation:
                # Subtract each other's rewards.
                for team_id in [TEAM_RADIANT, TEAM_DIRE]:
                    for player in players[team_id]:
                        player.rewards[-1]['enemy'] = -reward_sum_step[
                            OPPOSITE_TEAM[team_id]]

                # Roll out every time a player has queued a full rollout's worth of steps.
                for player in [*players[TEAM_RADIANT], *players[TEAM_DIRE]]:
                    if player.steps_queued > 0 and player.steps_queued % self.rollout_size == 0:
                        await player.rollout()

            if done:
                break

        if end_state in [
                Status.Value('RESOURCE_EXHAUSTED'),
                Status.Value('FAILED_PRECONDITION'),
                Status.Value('OUT_OF_RANGE')
        ]:
            # Bad end state. We don't want to roll this one out.
            logger.warning(
                'Bad end state `{}`, not rolling out game (dota_time={})'.
                format(end_state, dota_time))
            return
        # drawing.save(stem=game_id)  # HACK

        # Finish (e.g. final rollout or send validation metrics).
        for player in [*players[TEAM_RADIANT], *players[TEAM_DIRE]]:
            player.process_endstate(end_state)
            await player.finish()

        # TODO(tzaman): the worldstate ends when game is over. the worldstate doesn't have info
        # about who won the game: so we need to get info from that somehow

        logger.info('Game finished.')
async def main():
    """Drive a GUI-hosted game by repeatedly ordering Radiant to move.

    Connects to the DotaService at 192.168.1.17:13337, resets a game with ten
    Pudge picks (only the first Radiant pick is controlled), sends a warm-up
    move order for 15 act/observe cycles, and then keeps sending a second move
    order in an endless act/observe loop. The coroutine does not return on its
    own.
    """
    move_to_position = (
        CMsgBotWorldState.Action.Type.DOTA_UNIT_ORDER_MOVE_TO_POSITION)

    # Warm-up order used for the first few ticks.
    warmup_target = CMsgBotWorldState.Vector(x=-394, y=-486, z=204)
    warmup_order = CMsgBotWorldState.Action(
        actionType=move_to_position,
        moveToLocation=CMsgBotWorldState.Action.MoveToLocation(
            location=warmup_target),
    )
    dummy_action = Actions(
        actions=CMsgBotWorldState.Actions(actions=[warmup_order]),
        team_id=Team.TEAM_RADIANT,
    )

    # Connect to the DotaService.
    env = DotaServiceStub(Channel('192.168.1.17', 13337))

    # (team, control mode) per pick: one controlled Radiant hero, four default
    # Radiant heroes, then five default Dire heroes.
    default_mode = HeroControlMode.HERO_CONTROL_MODE_DEFAULT
    pick_layout = (
        [(Team.TEAM_RADIANT, HeroControlMode.HERO_CONTROL_MODE_CONTROLLED)]
        + [(Team.TEAM_RADIANT, default_mode)] * 4
        + [(Team.TEAM_DIRE, default_mode)] * 5
    )
    hero_picks = [
        HeroPick(
            team_id=team,
            hero_id=Hero.NPC_DOTA_HERO_PUDGE,
            control_mode=mode,
        )
        for team, mode in pick_layout
    ]

    # Get the initial observation.
    game_config = GameConfig(
        host_mode=HostMode.HOST_MODE_GUI,
        hero_picks=hero_picks,
        ticks_per_observation=30,
    )
    observation = await env.reset(game_config)

    radiant_view = ObserveConfig(team_id=Team.TEAM_RADIANT)

    # Warm-up phase: act, observe, and print a progress dot each cycle.
    for _ in range(15):
        await env.act(dummy_action)
        observation = await env.observe(radiant_view)
        print(".", end="")

    print()

    # TODO: Should really get unit ID from worldstate
    move_target = CMsgBotWorldState.Vector(x=-394, y=-486, z=0)
    move_order = CMsgBotWorldState.Action(
        actionDelay=0,
        actionType=move_to_position,
        moveToLocation=CMsgBotWorldState.Action.MoveToLocation(
            location=move_target),
        player=0,
    )
    move_action = Actions(
        actions=CMsgBotWorldState.Actions(actions=[move_order]),
        team_id=Team.TEAM_RADIANT,
    )
    print(f"moving {move_action}")

    # Keep issuing the move order, observing after every action.
    while True:
        await env.act(move_action)
        observation = await env.observe(
            ObserveConfig(team_id=Team.TEAM_RADIANT))