Example #1
0
    async def observe(self, stream):
        """Send the most recent world state for the requested team.

        Reads one request from ``stream``. If world states are already queued for that
        team, every one but the newest is discarded and the newest is sent back with an
        ``OK`` status. If the queue is empty, waits up to ``self.observe_timeout``
        seconds for the next one; if that wait times out (or is cancelled), the final
        state is read from the game log and sent back as the end status instead.
        """
        logger.debug('DotaService::observe()')

        msg = await stream.recv_message()
        team = msg.team_id
        pending = self.dota_game.worldstate_queues[team]

        if not pending.empty():
            # Drain the backlog, keeping only the newest world state.
            latest = pending.get_nowait()
            while not pending.empty():
                latest = pending.get_nowait()
        else:
            try:
                latest = await asyncio.wait_for(pending.get(), timeout=self.observe_timeout)
            except (asyncio.TimeoutError, asyncio.CancelledError):
                # Nothing arrived in time, which probably means the game is done.
                final_state = await self.dota_game.get_final_state_from_log()
                end_observation = Observation(
                    status=self.END_STATES[final_state],
                    team_id=team,
                )
                await stream.send_message(end_observation)
                return

        # Return the response.
        ok_observation = Observation(
            status=Status.Value('OK'),
            world_state=latest,
            team_id=team,
        )
        await stream.send_message(ok_observation)
Example #2
0
    async def observe(self, stream):
        """Send the next world state for the requested team.

        Reads one request from ``stream`` and waits up to ``self.observe_timeout``
        seconds for the next world state on that team's queue, which is then sent back
        with an ``OK`` status. If the wait times out (or is cancelled), the final state
        is read from the game log and sent back as the end status instead.
        """
        logger.debug('DotaService::observe()')

        msg = await stream.recv_message()
        team = msg.team_id
        pending = self.dota_game.worldstate_queues[team]

        try:
            world_state = await asyncio.wait_for(
                pending.get(), timeout=self.observe_timeout)
        except (asyncio.TimeoutError, asyncio.CancelledError):
            # Nothing arrived in time, which probably means the game is done.
            final_state = await self.dota_game.get_final_state_from_log()
            await stream.send_message(
                Observation(status=self.END_STATES[final_state], team_id=team))
            return

        # Consumer and producer must be exactly in step: after taking one world state,
        # nothing else should still be waiting in the queue.
        assert pending.qsize() == 0

        # Return the response.
        await stream.send_message(
            Observation(status=Status.Value('OK'), world_state=world_state, team_id=team))
Example #3
0
    async def play(self, config, game_id):
        """Play one full game, driving every controlled bot hero with a policy.

        Resets the Dota service with ``config`` and creates a ``Player`` for each bot
        whose pick requests ``HERO_CONTROL_MODE_CONTROLLED``. It then loops
        observe -> reward -> act for both teams until the service reports a non-OK
        status or ``self.max_dota_time`` is reached.

        Outside validation, each player's step reward gets an ``'enemy'`` term (the
        negated reward sum of the opposing team for the same step), and players are
        rolled out every ``self.rollout_size`` queued steps. Games ending in
        RESOURCE_EXHAUSTED, FAILED_PRECONDITION or OUT_OF_RANGE are discarded without
        a final rollout. Otherwise every player processes the end state and finishes.

        Args:
            config: game configuration passed to ``dota_service.reset``; its
                ``hero_picks`` are paired positionally with the players in the reset
                response.
            game_id: identifier forwarded to every ``Player``.
        """
        logger.info('Starting game.')

        # Use the latest weights by default.
        use_latest_weights = {TEAM_RADIANT: True, TEAM_DIRE: True}
        if random.random() > self.latest_weights_prob:
            # Randomly pick the team that will use the old weights.
            old_model_team = random.choice([TEAM_RADIANT, TEAM_DIRE])
            use_latest_weights[old_model_team] = False

        # TODO(tzaman): drawing should include what's visible to the player
        drawing = Drawing()

        # Reset and obtain the initial observation. This dictates who we are controlling.
        # It is done before the players are defined, because there might be humans playing
        # that take up bot positions.
        response = await asyncio.wait_for(self.dota_service.reset(config),
                                          timeout=120)

        player_request = config.hero_picks
        players_response = response.players  # Lists all human and bot players.
        players = {TEAM_RADIANT: [], TEAM_DIRE: []}
        for p_req, p_res in zip(player_request, players_response):
            # Requests and responses are paired by position; make sure each pair really
            # describes the same team. The previous check compared p_req with itself.
            assert p_req.team_id == p_res.team_id  # TODO(tzaman): more tests?
            if p_res.is_bot and p_req.control_mode == HERO_CONTROL_MODE_CONTROLLED:
                player = Player(
                    game_id=game_id,
                    player_id=p_res.id,
                    team_id=p_res.team_id,
                    hero=p_res.hero,
                    experience_channel=self.experience_channel,
                    use_latest_weights=use_latest_weights[p_res.team_id],
                    drawing=drawing,
                    validation=self.validation,
                )
                players[p_res.team_id].append(player)

        prev_obs = {
            TEAM_RADIANT: response.world_state_radiant,
            TEAM_DIRE: response.world_state_dire,
        }
        done = False
        dota_time = -float('Inf')
        end_state = None
        while dota_time < self.max_dota_time:
            reward_sum_step = {TEAM_RADIANT: 0, TEAM_DIRE: 0}
            # Teams whose players appended a reward during this step. If the game ends
            # part-way through a step, the remaining team's last reward still belongs to
            # the previous step (or does not exist yet on the very first step).
            teams_stepped = []
            for team_id in [TEAM_RADIANT, TEAM_DIRE]:
                logger.debug('\ndota_time={:.2f}, team={}'.format(
                    dota_time, team_id))

                response = await self.dota_service.observe(
                    ObserveConfig(team_id=team_id))
                if response.status != Status.Value('OK'):
                    end_state = response.status
                    done = True
                    break
                obs = response.world_state
                dota_time = obs.dota_time

                # We now loop over each player in this team and get each player's actions.
                actions = []
                for player in players[team_id]:
                    player.compute_reward(prev_obs=prev_obs[team_id], obs=obs)
                    reward_sum_step[team_id] += sum(
                        player.rewards[-1].values())
                    with torch.no_grad():
                        actions_player = player.obs_to_actions(obs=obs)
                    actions.extend(actions_player)
                teams_stepped.append(team_id)

                actions_pb = CMsgBotWorldState.Actions(actions=actions)
                actions_pb.dota_time = obs.dota_time

                await self.dota_service.act(
                    Actions(actions=actions_pb, team_id=team_id))

                prev_obs[team_id] = obs

            if not self.validation:
                # Subtract each other's rewards, but only for teams that produced a reward
                # this step. Touching the other team would clobber the previous step's
                # 'enemy' term, or raise IndexError when it has no reward yet.
                for team_id in teams_stepped:
                    enemy_reward = -reward_sum_step[OPPOSITE_TEAM[team_id]]
                    for player in players[team_id]:
                        player.rewards[-1]['enemy'] = enemy_reward

                for player in [*players[TEAM_RADIANT], *players[TEAM_DIRE]]:
                    if player.steps_queued > 0 and player.steps_queued % self.rollout_size == 0:
                        await player.rollout()

            if done:
                break

        if end_state in [
                Status.Value('RESOURCE_EXHAUSTED'),
                Status.Value('FAILED_PRECONDITION'),
                Status.Value('OUT_OF_RANGE')
        ]:
            # Bad end state. We don't want to roll this one out.
            logger.warning(
                'Bad end state `{}`, not rolling out game (dota_time={})'.
                format(end_state, dota_time))
            return
        # drawing.save(stem=game_id)  # HACK

        # Finish (e.g. final rollout or send validation metrics).
        for player in [*players[TEAM_RADIANT], *players[TEAM_DIRE]]:
            player.process_endstate(end_state)
            await player.finish()

        # TODO(tzaman): the worldstate ends when game is over. the worldstate doesn't have info
        # about who won the game: so we need to get info from that somehow

        logger.info('Game finished.')
Example #4
0
class Player:
    """One policy-controlled hero in a single game.

    Every step the player records the policy input, the chosen action, the
    selected-heads mask and the reward. ``rollout`` packs these, publishes them to the
    experience queue when the player uses the latest weights, and clears them.
    ``finish`` rolls out what is left, or writes summary metrics instead when in
    validation mode.
    """

    # Maps a winning end status to the team that won.
    END_STATUS_TO_TEAM = {
        Status.Value('RADIANT_WIN'): TEAM_RADIANT,
        Status.Value('DIRE_WIN'): TEAM_DIRE,
    }

    def __init__(self, game_id, player_id, team_id, hero, experience_channel,
                 use_latest_weights, drawing, validation):
        """Create the player and choose the policy it acts with.

        Args:
            game_id: identifier of the game; included in published experience.
            player_id: Dota player id this object controls.
            team_id: team of the player; compared against ``END_STATUS_TO_TEAM`` values.
            hero: hero of the player (stored; not otherwise used in this class).
            experience_channel: channel whose ``basic_publish`` receives rolled-out
                experience.
            use_latest_weights: play with the latest weights (True) or the oldest
                stored weights (False). Experience is only published when True.
            drawing: shared ``Drawing``, stepped on every ``compute_reward``.
            validation: when True, load a separate (non-synced) copy of the latest
                weights, and write metrics on ``finish`` instead of rolling out.
        """
        self.game_id = game_id
        self.player_id = player_id
        self.team_id = team_id
        self.hero = hero
        self.experience_channel = experience_channel
        self.use_latest_weights = use_latest_weights

        # Per-step buffers, cleared on every rollout.
        self.policy_inputs = []
        self.actions = []
        self.selected_heads_mask = []
        self.rewards = []
        self.drawing = drawing
        self.validation = validation

        self.creeps_had_spawned = False
        self.prev_level = 0  # Last seen hero level; used to detect level-ups.

        use_synced_weights = use_latest_weights and not self.validation

        if use_synced_weights:
            # Share the live policy object, which keeps being updated while the agent is playing.
            self.policy = weight_store.latest_policy
        else:
            # Load a separate, non-synchronized copy of the weights. Validation always
            # evaluates the latest weights; otherwise the oldest weights are used.
            if self.validation or use_latest_weights:
                version, state_dict = weight_store.latest_weights()
            else:
                version, state_dict = weight_store.oldest_weights()
            self.policy = Policy()
            self.policy.load_state_dict(state_dict, strict=True)
            self.policy.weight_version = version
            self.policy.eval()  # Set to evaluation mode.
        self.hidden = self.policy.init_hidden()

        logger.info('Player {} using weights version {}'.format(
            self.player_id, self.policy.weight_version))

    def summed_subrewards(self):
        """Return a dict mapping each reward key to its sum over all queued steps."""
        reward_counter = Counter()
        for r in self.rewards:
            reward_counter.update(r)
        return dict(reward_counter)

    def print_reward_summary(self):
        """Log the total reward and the per-key subreward sums of the queued steps."""
        subrewards = self.summed_subrewards()
        reward_sum = sum(subrewards.values())
        logger.info('Player {} reward sum: {:.2f} subrewards:\n{}'.format(
            self.player_id, reward_sum, pformat(subrewards)))

    def process_endstate(self, end_state):
        """Add the game outcome to the last reward as its ``'win'`` term.

        The term is +1 if this player's team won, -1 if the other team won, and -0.25
        for any other end state (no clear winner). Does nothing if there are no rewards.
        """
        if not self.rewards:
            return
        if end_state in self.END_STATUS_TO_TEAM:
            won = self.team_id == self.END_STATUS_TO_TEAM[end_state]
            self.rewards[-1]['win'] = 1 if won else -1
            return

        # Add a negative win reward, because we did not have a clear winner.
        self.rewards[-1]['win'] = -0.25

    @staticmethod
    def pack_observations(inputs):
        """Convert the list-of-dicts into a dict with a single tensor per input for the sequence."""
        d = {key: [] for key in Policy.INPUT_KEYS}
        for inp in inputs:  # go over steps: (list of dicts)
            for k, v in inp.items():  # go over each input in the step (dict)
                d[k].append(v)

        # Pack it up
        for k, v in d.items():
            # Stack all inputs into a single tensor along a new sequence axis.
            # We formerly padded this instead of stacking, but that made it hard to keep
            # track of the chosen action ids related to units.
            d[k] = torch.stack(v)
        return d

    @staticmethod
    def pack_rewards(inputs):
        """Pack a list of reward dicts into a dense 2D array of shape (steps, len(REWARD_KEYS))."""
        t = np.zeros([len(inputs), len(REWARD_KEYS)], dtype=np.float32)
        for i, reward in enumerate(inputs):
            for ir, key in enumerate(REWARD_KEYS):
                t[i, ir] = reward[key]
        return t

    @staticmethod
    def pack_actions(inputs):
        """Flatten each step's action selections and stack them per action head."""
        data = {key: [] for key in Policy.ACTION_OUTPUT_COUNTS.keys()}
        for inp in inputs:
            inp = Policy.flatten_selections(inputs=inp)
            for key in data:
                data[key].append(inp[key])
        for k, v in data.items():
            data[k] = torch.stack(v)
        return data

    @staticmethod
    def pack_masks(inputs):
        """Concatenate each step's per-head masks into one tensor per action head."""
        data = {key: [] for key in Policy.ACTION_OUTPUT_COUNTS.keys()}
        for inp in inputs:
            for key in data:
                data[key].append(inp[key])
        for k, v in data.items():
            # Concatenate over sequence axis and remove batch axis
            data[k] = torch.cat(v, dim=1).squeeze(0)
        return data

    def _send_experience_rmq(self):
        """Pickle the packed experience of the queued steps and publish it to the experience queue."""
        logger.debug('_send_experience_rmq')

        # Pack all the policy inputs into dense tensors
        observations = self.pack_observations(inputs=self.policy_inputs)
        masks = self.pack_masks(self.selected_heads_mask)
        actions = self.pack_actions(self.actions)
        rewards = self.pack_rewards(inputs=self.rewards)

        data = pickle.dumps({
            'game_id': self.game_id,
            'team_id': self.team_id,
            'player_id': self.player_id,
            'weight_version': self.policy.weight_version,
            'canvas': self.drawing.canvas,
            'observations': observations,
            'masks': masks,
            'actions': actions,
            'rewards': rewards,
        })
        self.experience_channel.basic_publish(
            exchange='', routing_key=EXPERIENCE_QUEUE_NAME, body=data)

    @property
    def steps_queued(self):
        """Number of steps recorded since the last rollout."""
        return len(self.rewards)

    def write_validation(self):
        """Write canvas, step count and reward summaries, then upload the event file to GCS.

        Everything is logged at the policy's weight version as the global step.
        """
        it = self.policy.weight_version
        writer.add_image('game/canvas',
                         self.drawing.canvas,
                         it,
                         dataformats='HWC')
        writer.add_scalar('game/steps', self.steps_queued, it)
        subrewards = self.summed_subrewards()
        reward_sum = sum(subrewards.values())
        writer.add_scalar('game/rewards_sum', reward_sum, it)
        for key, reward in subrewards.items():
            writer.add_scalar('game/rewards_{}'.format(key), reward, it)
        # Upload events to GCS
        writer.file_writer.flush()  # Flush before uploading
        events_filename = events_filename_from_writer(writer)
        blob = gcs_bucket().blob(events_filename)
        blob.upload_from_filename(filename=events_filename)

    async def finish(self):
        """End-of-game hook: write validation metrics, or roll out the remaining steps."""
        if self.validation:
            self.write_validation()
        else:
            await self.rollout()

    async def rollout(self):
        """Publish the queued experience (only with the latest weights) and clear the buffers."""
        logger.info('Player {} rollout, len={}'.format(self.player_id,
                                                       self.steps_queued))

        if not self.rewards:
            logger.info('nothing to roll out.')
            return

        self.print_reward_summary()

        if self.use_latest_weights:
            self._send_experience_rmq()
        else:
            logger.info('Not using latest weights: not rolling out.')

        # Reset states.
        self.policy_inputs = []
        self.rewards = []
        self.actions = []
        self.selected_heads_mask = []

    @staticmethod
    def unit_separation(state, team_id):
        """Split ``state.units`` into allied/enemy groups in a single pass.

        This saves later code from iterating the full unit list repeatedly.

        Returns:
            ``(allied_heroes, enemy_heroes, allied_nonheroes, enemy_nonheroes,
            allied_creep, enemy_creep, allied_towers, enemy_towers)``.
            "Nonheroes" are CREEP_HERO units and "creep" are LANE_CREEP units. Only
            towers whose name ends in ``1_mid`` are kept, and units of any other type
            are dropped.
        """
        unit_type = CMsgBotWorldState.UnitType
        hero_type = unit_type.Value('HERO')
        creep_hero_type = unit_type.Value('CREEP_HERO')
        lane_creep_type = unit_type.Value('LANE_CREEP')
        tower_type = unit_type.Value('TOWER')

        allied_heroes = []
        enemy_heroes = []
        allied_nonheroes = []
        enemy_nonheroes = []
        allied_creep = []
        enemy_creep = []
        allied_towers = []
        enemy_towers = []
        for unit in state.units:
            allied = unit.team_id == team_id
            if unit.unit_type == hero_type:
                (allied_heroes if allied else enemy_heroes).append(unit)
            elif unit.unit_type == creep_hero_type:
                (allied_nonheroes if allied else enemy_nonheroes).append(unit)
            elif unit.unit_type == lane_creep_type:
                (allied_creep if allied else enemy_creep).append(unit)
            elif unit.unit_type == tower_type and unit.name.endswith('1_mid'):
                # Only consider the mid tower for now.
                (allied_towers if allied else enemy_towers).append(unit)

        return (allied_heroes, enemy_heroes, allied_nonheroes, enemy_nonheroes,
                allied_creep, enemy_creep, allied_towers, enemy_towers)

    @staticmethod
    def unit_matrix(unit_list, hero_unit, only_self=False, max_units=16):
        """Encode up to ``max_units`` living units as a feature matrix plus target handles.

        Args:
            unit_list: units to encode; dead units are skipped.
            hero_unit: the controlled hero. Distances, attack range and attack
                relations are computed relative to it.
            only_self: if True, only ``hero_unit`` itself is encoded.
            max_units: number of rows in the output.

        Returns:
            ``(m, handles)``. ``m`` is a ``(max_units, 12)`` float tensor with one row
            per encoded unit and zero rows for unused slots. ``handles`` is a
            ``(max_units,)`` integer tensor holding the unit's handle when the hero may
            attack it, and -1 otherwise (including unused slots).
        """
        # The matrix always has max_units rows, so unused slots act as 'zero' units and
        # the policy never sees an empty input. We can't just pad this, because we would
        # otherwise lose track of corresponding chosen actions relating to output indices.
        # Even if we could, batching multiple sequences together would then be another
        # error-prone nightmare.
        # An explicit integer dtype keeps unit handles exact.
        handles = torch.full([max_units], -1, dtype=torch.long)
        m = torch.zeros(max_units, 12)
        tower_type = CMsgBotWorldState.UnitType.Value('TOWER')
        i = 0
        for unit in unit_list:
            if not unit.is_alive:
                continue
            if only_self and unit != hero_unit:
                continue
            if i >= max_units:
                break
            rel_hp = 1.0 - (unit.health / unit.health_max)
            rel_mana = 0.0
            if unit.mana_max > 0:
                rel_mana = 1.0 - (unit.mana / unit.mana_max)
            loc_x = unit.location.x / MAP_HALF_WIDTH
            loc_y = unit.location.y / MAP_HALF_WIDTH
            loc_z = (unit.location.z / 512.) - 0.5
            distance_x = (hero_unit.location.x - unit.location.x)
            distance_y = (hero_unit.location.y - unit.location.y)
            distance = math.sqrt(distance_x**2 + distance_y**2)
            norm_distance = (distance / MAP_HALF_WIDTH) - 0.5

            # Get the direction where the unit is facing (facing is in degrees).
            facing_rad = unit.facing * (2 * math.pi) / 360
            facing_sin = math.sin(facing_rad)
            facing_cos = math.cos(facing_rad)

            # Normalized boolean [-0.5 or 0.5]: is the unit within the hero's attack range?
            in_attack_range = float(distance <= hero_unit.attack_range) - 0.5

            # Normalized booleans [-0.5 or 0.5]: is that unit currently right-click
            # attacking me, and am I attacking it?
            is_attacking_me = float(is_unit_attacking_unit(unit, hero_unit)) - 0.5
            me_attacking_unit = float(is_unit_attacking_unit(hero_unit, unit)) - 0.5

            # 0.5 if any of the unit's abilities is in its cast phase or channeling.
            in_ability_phase = -0.5
            for a in unit.abilities:
                if a.is_in_ability_phase or a.is_channeling:
                    in_ability_phase = 0.5
                    break

            m[i] = (torch.tensor([
                rel_hp, loc_x, loc_y, loc_z, norm_distance, facing_sin,
                facing_cos, in_attack_range, is_attacking_me,
                me_attacking_unit, rel_mana, in_ability_phase
            ]))

            # Because we are currently only attacking, check if these units are valid
            # targets.
            # HACK: Make a nice interface for this, per enum used?
            is_tower = unit.unit_type == tower_type
            if unit.is_invulnerable or unit.is_attack_immune:
                handles[i] = -1
            elif (unit.team_id == OPPOSITE_TEAM[hero_unit.team_id] and is_tower
                  and unit.anim_activity == 1500):
                # Enemy tower. Due to a dota bug, the bot API can only attack towers (and
                # move to them) when they are attacking (activity 1503; stationary is 1500).
                handles[i] = -1
            elif unit.team_id == hero_unit.team_id and is_tower:
                # Its own tower.
                handles[i] = -1
            elif unit.team_id == hero_unit.team_id and (
                    unit.health / unit.health_max) > 0.5:
                # Not deniable (allied unit above half health).
                handles[i] = -1
            else:
                handles[i] = unit.handle

            i += 1
        return m, handles

    def select_action(self, world_state, hero_unit):
        """Run the policy on ``world_state`` for ``hero_unit`` and select an action.

        Also advances ``self.hidden`` to the policy's new hidden state.

        Returns:
            ``(policy_input, action_dict, selected_heads_mask, unit_handles)``, where
            ``unit_handles`` concatenates the handles of all encoded unit groups in the
            order the policy sees them.

        Raises:
            ValueError: if, at the first check after dota_time 0, no allied lane creeps
                are present (see dotaclient/issues/15).
        """
        dota_time_norm = world_state.dota_time / 1200.  # Normalize by 20 minutes
        creepwave_sin = math.sin(world_state.dota_time * (2. * math.pi) / 60)
        team_float = -.2 if self.team_id == TEAM_DIRE else .2

        env_state = torch.Tensor([dota_time_norm, creepwave_sin, team_float])

        # Separate units into unit-type groups for both teams, iterating the full
        # world-state unit list only once. Further filtering only iterates the
        # type-specific list of interest.
        ah, eh, anh, enh, ac, ec, at, et = self.unit_separation(
            world_state, hero_unit.team_id)

        # Process units into Tensors & Handles
        allied_heroes, allied_hero_handles = self.unit_matrix(
            unit_list=ah,
            hero_unit=hero_unit,
            only_self=True,  # For now, ignore teammates.
            max_units=1,
        )

        enemy_heroes, enemy_hero_handles = self.unit_matrix(
            unit_list=eh,
            hero_unit=hero_unit,
            max_units=5,
        )

        allied_nonheroes, allied_nonhero_handles = self.unit_matrix(
            unit_list=[*anh, *ac],
            hero_unit=hero_unit,
            max_units=16,
        )

        enemy_nonheroes, enemy_nonhero_handles = self.unit_matrix(
            unit_list=[*enh, *ec],
            hero_unit=hero_unit,
            max_units=16,
        )

        allied_towers, allied_tower_handles = self.unit_matrix(
            unit_list=at,
            hero_unit=hero_unit,
            max_units=1,
        )

        enemy_towers, enemy_tower_handles = self.unit_matrix(
            unit_list=et,
            hero_unit=hero_unit,
            max_units=1,
        )

        unit_handles = torch.cat([
            allied_hero_handles, enemy_hero_handles, allied_nonhero_handles,
            enemy_nonhero_handles, allied_tower_handles, enemy_tower_handles
        ])

        if not self.creeps_had_spawned and world_state.dota_time > 0.:
            # Check that creeps have spawned. See dotaclient/issues/15.
            # TODO(tzaman): this should be handled by DotaService.
            # self.creeps_had_spawned = bool((allied_nonhero_handles != -1).any())
            self.creeps_had_spawned = len(ac) > 0
            if not self.creeps_had_spawned:
                raise ValueError(
                    'Creeps have not spawned at timestep {}'.format(
                        world_state.dota_time))

        policy_input = {
            'env': env_state,
            'allied_heroes': allied_heroes,
            'enemy_heroes': enemy_heroes,
            'allied_nonheroes': allied_nonheroes,
            'enemy_nonheroes': enemy_nonheroes,
            'allied_towers': allied_towers,
            'enemy_towers': enemy_towers,
        }

        logger.debug('policy_input:\n' + pformat(policy_input))

        heads_logits, value, self.hidden = self.policy.single(
            **policy_input, hidden=self.hidden)

        logger.debug('heads_logits:\n' + pformat(heads_logits))
        logger.debug('value={}'.format(value))

        # Get valid actions. This mask contains all viable actions.
        action_masks = Policy.action_masks(player_unit=hero_unit,
                                           unit_handles=unit_handles)
        logger.debug('action_masks:\n' + pformat(action_masks))

        # From the heads logits and their masks, select the actions.
        action_dict = Policy.select_actions(heads_logits=heads_logits,
                                            masks=action_masks)
        logger.debug('action_dict:\n' + pformat(action_dict))

        # Given the action selections, get the head mask.
        head_masks = Policy.head_masks(selections=action_dict)
        logger.debug('head_masks:\n' + pformat(head_masks))

        # Combine the head mask and the selection mask, to get all relevant probabilities of the
        # current action.
        selected_heads_mask = {
            key: head_masks[key] & action_masks[key]
            for key in head_masks
        }
        logger.debug('selected_heads_mask:\n' + pformat(selected_heads_mask))

        return policy_input, action_dict, selected_heads_mask, unit_handles

    def action_to_pb(self, action_dict, state, unit_handles):
        """Convert a selected action into a ``CMsgBotWorldState.Action`` for this player.

        ``action_dict['enum']`` selects the order:
        0 = none;
        1 = move directly to the hero's location offset by
            ``Policy.MOVE_ENUMS[action_dict['x']]`` / ``[action_dict['y']]``;
        2 = attack ``unit_handles[action_dict['target_unit']]`` once (target -1 when no
            target was selected);
        3 = cast the ability in slot ``action_dict['ability']`` without a target.

        Raises:
            ValueError: for any other action enum.
        """
        # TODO(tzaman): Reduce the scope of this function. Make it a converter only.
        hero_unit = get_unit(state, player_id=self.player_id)
        action_pb = CMsgBotWorldState.Action()
        action_pb.actionDelay = 0  # action_dict['delay'] * DELAY_ENUM_TO_STEP
        action_enum = action_dict['enum']

        if action_enum == 0:
            action_pb.actionType = CMsgBotWorldState.Action.Type.Value(
                'DOTA_UNIT_ORDER_NONE')
        elif action_enum == 1:
            action_pb.actionType = CMsgBotWorldState.Action.Type.Value(
                'DOTA_UNIT_ORDER_MOVE_DIRECTLY')
            m = CMsgBotWorldState.Action.MoveToLocation()
            hero_location = hero_unit.location
            m.location.x = hero_location.x + Policy.MOVE_ENUMS[
                action_dict['x']]
            m.location.y = hero_location.y + Policy.MOVE_ENUMS[
                action_dict['y']]
            m.location.z = 0
            action_pb.moveDirectly.CopyFrom(m)
        elif action_enum == 2:
            action_pb.actionType = CMsgBotWorldState.Action.Type.Value(
                'DOTA_UNIT_ORDER_ATTACK_TARGET')
            m = CMsgBotWorldState.Action.AttackTarget()
            if 'target_unit' in action_dict:
                m.target = unit_handles[action_dict['target_unit']]
            else:
                m.target = -1
            m.once = True
            action_pb.attackTarget.CopyFrom(m)
        elif action_enum == 3:
            # NOTE(review): this recreates the message, dropping the explicit actionDelay
            # set above. It is kept as-is because the serialized message (field presence)
            # may matter downstream; confirm before unifying with the other branches.
            action_pb = CMsgBotWorldState.Action()
            action_pb.actionType = CMsgBotWorldState.Action.Type.Value(
                'DOTA_UNIT_ORDER_CAST_NO_TARGET')
            action_pb.cast.abilitySlot = action_dict['ability']
        else:
            raise ValueError("unknown action {}".format(action_enum))
        action_pb.player = self.player_id
        return action_pb

    def train_ability(self, hero_unit):
        """Return a train-ability action if the hero leveled up since the last call, else None."""
        # Check if we leveled up
        leveled_up = hero_unit.level > self.prev_level
        if leveled_up:
            self.prev_level = hero_unit.level
            # Just try to level up the first ability.
            action_pb = CMsgBotWorldState.Action()
            action_pb.actionType = CMsgBotWorldState.Action.Type.Value(
                'DOTA_UNIT_ORDER_TRAIN_ABILITY')
            action_pb.player = self.player_id
            action_pb.trainAbility.ability = "nevermore_shadowraze1"
            return action_pb
        return None

    def obs_to_actions(self, obs):
        """Select and record this step's action; return it (plus a level-up order, if any)."""
        actions = []
        hero_unit = get_unit(state=obs, player_id=self.player_id)

        policy_input, action_dict, selected_heads_mask, unit_handles = self.select_action(
            world_state=obs,
            hero_unit=hero_unit,
        )

        self.policy_inputs.append(policy_input)
        self.actions.append(action_dict)
        self.selected_heads_mask.append(selected_heads_mask)
        logger.debug('action:\n' + pformat(action_dict))

        action_pb = self.action_to_pb(action_dict=action_dict,
                                      state=obs,
                                      unit_handles=unit_handles)
        actions.append(action_pb)

        level_pb = self.train_ability(hero_unit)
        if level_pb is not None:
            actions.append(level_pb)

        return actions

    def compute_reward(self, prev_obs, obs):
        """Step the shared drawing and append the reward between ``prev_obs`` and ``obs``."""
        # Draw.
        self.drawing.step(state=obs,
                          team_id=self.team_id,
                          player_id=self.player_id)

        reward = get_reward(prev_obs=prev_obs,
                            obs=obs,
                            player_id=self.player_id)
        self.rewards.append(reward)
Example #5
0
class DotaService:
    """Starts and manages a local Dota 2 game.

    ``dota_game`` holds the current ``DotaGame`` (``None`` until ``reset`` creates one).
    ``END_STATES`` maps a game's winner (``None``, ``TEAM_RADIANT`` or ``TEAM_DIRE`` --
    presumably as determined when the game ends) to the ``Status`` reported to clients.
    """

    END_STATES = {
        None: Status.Value('RESOURCE_EXHAUSTED'),
        TEAM_RADIANT: Status.Value('RADIANT_WIN'),
        TEAM_DIRE: Status.Value('DIRE_WIN'),
    }

    def __init__(self, dota_path, action_folder, remove_logs):
        """Validate the paths and store the configuration.

        Args:
            dota_path: Dota installation path, checked by ``verify_game_path``.
            action_folder: folder handed to each ``DotaGame``; must already exist.
            remove_logs: forwarded to each ``DotaGame``.

        Raises:
            ValueError: if ``action_folder`` is missing on Linux or macOS. The message
                explains how to create a ramdisk for it. On other platforms a missing
                folder is not reported here.
        """
        self.dota_path = dota_path
        self.action_folder = action_folder
        self.remove_logs = remove_logs

        # Initial assertions.
        verify_game_path(self.dota_path)

        if not os.path.exists(self.action_folder):
            if platform == "linux" or platform == "linux2":
                raise ValueError(
                    "Action folder '{}' not found.\nYou can create a 2GB ramdisk by executing:"
                    "`mkdir /tmpfs; mount -t tmpfs -o size=2048M tmpfs /tmpfs`\n"
                    "With Docker, you can add a tmpfs adding `--mount type=tmpfs,destination=/tmpfs`"
                    " to its run command.".format(self.action_folder))
            elif platform == "darwin":
                raise ValueError(
                    "Action folder '{}' not found.\nYou can create a 2GB ramdisk by executing:"
                    " `diskutil erasevolume HFS+ 'ramdisk' `hdiutil attach -nomount ram://4194304``"
                    .format(self.action_folder))

        self.dota_game = None
        super().__init__()

    @property
    def observe_timeout(self):
        """Seconds to wait for a world state: 10 in dedicated host mode, otherwise 3600."""
        if self.dota_game.host_mode == HOST_MODE_DEDICATED:
            return 10
        return 3600

    @staticmethod
    def stop_dota_pids():
        """Stop all dota processes.

        Stopping dota is necessary because only one client can be active at a time. So we clean
        up anything that already existed earlier, or a (hanging) mess we might have created.
        """
        # Use the pipe as a context manager so it is closed (and the shell reaped).
        with os.popen("ps -e | grep dota2 | awk '{print $1}'") as pipe:
            dota_pids = pipe.read().split()
        for pid in dota_pids:
            try:
                os.kill(int(pid), signal.SIGKILL)
            except ProcessLookupError:
                # The process already exited between listing and killing.
                pass

    def clean_resources(self):
        """Clean resources.

        Drop the current game (if any) and kill any running dota processes, so a new game
        can be started.
        """
        # TODO(tzaman): Currently semi-gracefully. Can be cleaner.
        if self.dota_game is not None:
            # await self.dota_game.close()
            self.dota_game = None
        self.stop_dota_pids()

    def reset(self, config):
        """Start a fresh game.

        Cleans up any previous game, creates a new ``DotaGame`` from ``config`` and
        runs it.
        """
        print('DotaService::reset()')

        logger.debug('config=\n{}'.format(config))

        self.clean_resources()

        # Create a new dota game instance.
        self.dota_game = DotaGame(
            dota_path=self.dota_path,
            action_folder=self.action_folder,
            remove_logs=self.remove_logs,
            host_timescale=config.host_timescale,
            ticks_per_observation=config.ticks_per_observation,
            hero_picks=config.hero_picks,
            game_mode=config.game_mode,
            host_mode=config.host_mode,
            game_id=config.game_id,
        )
        print(f'timescale: {self.dota_game.host_timescale}')

        # Start dota.
        self.dota_game.run()

    @staticmethod
    def players_to_pb(players):
        """Convert player dicts (keys 'id', 'team_id', 'is_bot', 'hero') into ``Player`` messages.

        Hero names are upper-cased.
        """
        return [
            Player(
                id=player['id'],
                team_id=player['team_id'],
                is_bot=player['is_bot'],
                hero=player['hero'].upper(),
            ) for player in players
        ]
Example #6
0
class DotaService(DotaServiceBase):
    """Service that drives a Dota 2 game through `DotaGame`.

    `reset` (re)starts a game and sends the initial worldstates of both teams,
    `observe` sends the latest worldstate of one team, and `act` forwards a team's
    actions to the running game.
    """

    # Maps the winner returned by `DotaGame.get_final_state_from_log()` to the status
    # `observe` sends when no new worldstate arrives in time. A `None` winner presumably
    # means no result could be read from the log — TODO confirm.
    END_STATES = {
        None: Status.Value('RESOURCE_EXHAUSTED'),
        TEAM_RADIANT: Status.Value('RADIANT_WIN'),
        TEAM_DIRE: Status.Value('DIRE_WIN'),
    }

    def __init__(self, dota_path, action_folder, remove_logs):
        """Store the service options and validate the environment.

        Args:
            dota_path: Dota 2 installation path; validated with `verify_game_path`.
            action_folder: Folder handed to every `DotaGame` created in `reset`
                (presumably where actions/configs are exchanged with the game). It must
                already exist.
            remove_logs: Forwarded to every `DotaGame` created in `reset`.

        Raises:
            ValueError: If `action_folder` does not exist. On Linux and macOS the
                message explains how to create a ramdisk for it.
        """
        self.dota_path = dota_path
        self.action_folder = action_folder
        self.remove_logs = remove_logs

        # Initial assertions.
        verify_game_path(self.dota_path)

        if not os.path.exists(self.action_folder):
            if platform in ("linux", "linux2"):
                raise ValueError(
                    "Action folder '{}' not found.\nYou can create a 2GB ramdisk by executing:"
                    "`mkdir /tmpfs; mount -t tmpfs -o size=2048M tmpfs /tmpfs`\n"
                    "With Docker, you can add a tmpfs adding `--mount type=tmpfs,destination=/tmpfs`"
                    " to its run command.".format(self.action_folder))
            if platform == "darwin":
                raise ValueError(
                    "Action folder '{}' not found.\nYou can create a 2GB ramdisk by executing:"
                    " `diskutil erasevolume HFS+ 'ramdisk' `hdiutil attach -nomount ram://4194304``"
                    .format(self.action_folder))
            # Any other platform previously fell through silently with a missing folder.
            raise ValueError("Action folder '{}' not found.".format(self.action_folder))

        self.dota_game = None
        super().__init__()

    @property
    def observe_timeout(self):
        """Seconds `observe` waits for a worldstate before assuming the game is over.

        Dedicated hosts get 10 seconds; every other host mode gets 3600 seconds.
        """
        is_dedicated = self.dota_game.host_mode == HOST_MODE_DEDICATED
        return 10 if is_dedicated else 3600

    @staticmethod
    def stop_dota_pids():
        """Stop all dota processes.

        Stopping dota is necessary because only one client can be active at a time. So we clean
        up anything that already existed earlier, or a (hanging) mess we might have created.

        Every process whose `ps -e` line contains "dota2" is sent SIGKILL. Processes that
        already exited, or that we are not permitted to signal, are skipped.
        """
        # Use the pipe as a context manager so it is closed and the shell process waited
        # for deterministically, instead of whenever the pipe object is garbage collected.
        with os.popen("ps -e | grep dota2 | awk '{print $1}'") as ps_output:
            dota_pids = ps_output.read().split()
        for pid in dota_pids:
            try:
                os.kill(int(pid), signal.SIGKILL)
            except (ProcessLookupError, PermissionError):
                # Already gone, or not ours to kill: best-effort cleanup, keep going.
                pass

    async def clean_resources(self):
        """Clean resources.

        Forget the current game (if any) and kill every running dota process, which
        puts the service back into a ready state.
        """
        # TODO(tzaman): Currently semi-gracefully. Can be cleaner.
        # The game is not closed explicitly (no `await self.dota_game.close()`); we only
        # drop the reference and rely on killing the dota processes below.
        self.dota_game = None
        self.stop_dota_pids()

    async def reset(self, stream):
        """reset method.

        This method should start up the dota game and the other required services.

        Receives the game config from `stream` and kills any previous game. It then starts
        a new `DotaGame` in a background task and waits for its lua config. Next it drains
        the initial worldstates of both teams and writes the latest dota_time to the live
        config as calibration. Finally it sends an `InitialObservation` back on `stream`.

        Raises:
            RuntimeError: If no initial worldstate arrived for Radiant or Dire.
        """
        logger.info('DotaService::reset()')

        config = await stream.recv_message()
        logger.debug('config=\n{}'.format(config))

        await self.clean_resources()

        if config.host_mode != HOST_MODE_DEDICATED:
            # for non dedicate server wait 5s before steam recognize dota2 dead
            await asyncio.sleep(5)

        # Create a new dota game instance.
        self.dota_game = DotaGame(
            dota_path=self.dota_path,
            action_folder=self.action_folder,
            remove_logs=self.remove_logs,
            host_timescale=config.host_timescale,
            ticks_per_observation=config.ticks_per_observation,
            hero_picks=config.hero_picks,
            game_mode=config.game_mode,
            host_mode=config.host_mode,
            game_id=config.game_id,
        )

        # Start dota in the background. Keep a reference to the task: the event loop only
        # holds weak references to tasks, so an unreferenced task may be garbage collected
        # before it finishes.
        self._dota_game_task = asyncio.create_task(self.dota_game.run())

        # We first wait for the lua config. TODO(tzaman): do this in DotaGame?
        logger.debug('::reset is awaiting lua config..')
        lua_config = await self.dota_game.lua_config_future
        logger.debug('::reset: lua config received={}'.format(lua_config))

        # Drain each team's queue until no new worldstate arrives for 0.5s, keeping only
        # the latest worldstate per team.
        data = {TEAM_RADIANT: None, TEAM_DIRE: None}
        for team_id in self.dota_game.worldstate_queues:
            queue = self.dota_game.worldstate_queues[team_id]
            try:
                while True:
                    data[team_id] = await asyncio.wait_for(queue.get(), timeout=0.5)
            except asyncio.TimeoutError:
                pass

        # Explicit checks instead of `assert`, which is stripped when running `python -O`.
        for team_id in (TEAM_RADIANT, TEAM_DIRE):
            if data[team_id] is None:
                raise RuntimeError(
                    'No initial worldstate received for team_id={}'.format(team_id))

        # NOTE: disabled check that both teams' initial worldstates share the same dota_time.
        #if data[TEAM_RADIANT].dota_time != data[TEAM_DIRE].dota_time:
        #    raise ValueError(
        #        'dota_time discrepancy in depleting initial worldstate queue.\n'
        #        'radiant={:.2f}, dire={:.2f}'.format(data[TEAM_RADIANT].dota_time, data[TEAM_DIRE].dota_time))

        last_dota_time = data[TEAM_RADIANT].dota_time

        # Now write the calibration file.
        live_config = {
            'calibration_dota_time': last_dota_time
        }
        self.dota_game.write_live_config(data=live_config)

        # Return the response.
        await stream.send_message(InitialObservation(
            world_state_radiant=data[TEAM_RADIANT],
            world_state_dire=data[TEAM_DIRE],
            players=self.players_to_pb(self.dota_game.players),
        ))

    @staticmethod
    def players_to_pb(players):
        """Convert player dicts into `Player` messages.

        Args:
            players: Iterable of dicts with 'id', 'team_id', 'is_bot' and 'hero' keys.

        Returns:
            List of `Player` messages in the same order, hero names upper-cased.
        """
        return [
            Player(
                id=player_info['id'],
                team_id=player_info['team_id'],
                is_bot=player_info['is_bot'],
                hero=player_info['hero'].upper(),
            )
            for player_info in players
        ]

    async def observe(self, stream):
        """Send the newest worldstate for the team named in the request.

        If worldstates are already queued, all of them are consumed and only the most
        recent one is sent. Otherwise this waits up to `observe_timeout` seconds for the
        next one. If none arrives (or the wait is cancelled), the game is assumed to be
        over. In that case an `Observation` without a world state is sent, carrying the
        status that `END_STATES` maps the logged final state to.
        """
        logger.debug('DotaService::observe()')

        request = await stream.recv_message()
        team = request.team_id
        pending = self.dota_game.worldstate_queues[team]

        if not pending.empty():
            # Skip stale worldstates; only the newest one is kept.
            while not pending.empty():
                newest = pending.get_nowait()
        else:
            try:
                newest = await asyncio.wait_for(pending.get(), timeout=self.observe_timeout)
            except (asyncio.TimeoutError, asyncio.CancelledError):
                # A timeout probably means the game is done.
                # NOTE(review): catching CancelledError keeps this handler running after
                # it was cancelled; confirm that is intended.
                final_state = await self.dota_game.get_final_state_from_log()
                await stream.send_message(Observation(
                    status=self.END_STATES[final_state],
                    team_id=team,
                ))
                return

        # Return the response.
        await stream.send_message(Observation(
            status=Status.Value('OK'),
            world_state=newest,
            team_id=team,
        ))

    async def act(self, stream):
        """Forward one team's actions from the request to the running game.

        The request's `actions` message is converted to a dict with `MessageToDict`,
        written through `DotaGame.write_action`, and an `Empty` reply is sent.
        """
        logger.debug('DotaService::act()')

        request = await stream.recv_message()
        team = request.team_id
        action_dict = MessageToDict(request.actions)
        logger.debug('team_id={}, actions=\n{}'.format(team, pformat(action_dict)))

        self.dota_game.write_action(data=action_dict, team_id=team)

        # Acknowledge with an empty reply.
        await stream.send_message(Empty())