def select_action(self, world_state, hero_unit):
    """Select the next action for `hero_unit` given the current world state.

    Builds the policy input (a small environment feature vector plus one
    feature matrix per unit group), runs a single policy step, and selects
    actions from the resulting logits under the mask of valid actions.

    Side effects: updates `self.hidden` with the hidden state returned by
    the policy, and may set `self.creeps_had_spawned`.

    Args:
        world_state: World-state object. This method reads its `dota_time`
            and passes the whole object to `self.unit_separation`.
        hero_unit: The controlled unit. Its `team_id` is used to separate
            units into allied and enemy groups.

    Returns:
        A tuple `(policy_input, action_dict, selected_heads_mask,
        unit_handles)`:
        `policy_input` is the dict of inputs fed to the policy,
        `action_dict` holds the selected actions, `selected_heads_mask` is
        the per-head AND of the head masks and the action masks, and
        `unit_handles` is the concatenation of all per-group handles.

    Raises:
        ValueError: If `dota_time` is positive, creeps had not been seen
            on an earlier call, and the allied creep list is empty.
    """
    # Local import: only needed for the DEBUG-level guard below.
    import logging

    # Formatting the debug payloads (pformat over dicts of tensors) is costly
    # and this method runs every step, so only do it when DEBUG is enabled.
    # NOTE(review): assumes `logger` is a stdlib logging.Logger - confirm.
    debug_enabled = logger.isEnabledFor(logging.DEBUG)

    dota_time_norm = world_state.dota_time / 1200.  # Normalize by 20 minutes
    # Sine with a 60-second period over game time (presumably tracking the
    # creep-wave cycle, going by the name).
    creepwave_sin = math.sin(world_state.dota_time * (2. * math.pi) / 60)
    team_float = -.2 if self.team_id == TEAM_DIRE else .2

    env_state = torch.Tensor([dota_time_norm, creepwave_sin, team_float])

    # Separate units into unit-type groups for both teams.
    # The goal is to iterate only once through the entire unit list
    # in the provided world-state protobuf and for further filtering
    # only iterate across the unit-type specific list of interest.
    ah, eh, anh, enh, ac, ec, at, et = self.unit_separation(
        world_state, hero_unit.team_id)

    # Process units into Tensors & Handles.
    allied_heroes, allied_hero_handles = self.unit_matrix(
        unit_list=ah,
        hero_unit=hero_unit,
        only_self=True,  # For now, ignore teammates.
        max_units=1,
    )

    enemy_heroes, enemy_hero_handles = self.unit_matrix(
        unit_list=eh,
        hero_unit=hero_unit,
        max_units=5,
    )

    allied_nonheroes, allied_nonhero_handles = self.unit_matrix(
        unit_list=[*anh, *ac],
        hero_unit=hero_unit,
        max_units=16,
    )

    enemy_nonheroes, enemy_nonhero_handles = self.unit_matrix(
        unit_list=[*enh, *ec],
        hero_unit=hero_unit,
        max_units=16,
    )

    allied_towers, allied_tower_handles = self.unit_matrix(
        unit_list=at,
        hero_unit=hero_unit,
        max_units=1,
    )

    enemy_towers, enemy_tower_handles = self.unit_matrix(
        unit_list=et,
        hero_unit=hero_unit,
        max_units=1,
    )

    # The concatenation order here must match the order in which the unit
    # groups are listed, since handles are looked up by position later.
    unit_handles = torch.cat([
        allied_hero_handles,
        enemy_hero_handles,
        allied_nonhero_handles,
        enemy_nonhero_handles,
        allied_tower_handles,
        enemy_tower_handles,
    ])

    if not self.creeps_had_spawned and world_state.dota_time > 0.:
        # Check that creeps have spawned. See dotaclient/issues/15.
        # TODO(tzaman): this should be handled by DotaService.
        # The check looks at the raw allied creep list `ac`.
        self.creeps_had_spawned = len(ac) > 0
        if not self.creeps_had_spawned:
            raise ValueError(
                'Creeps have not spawned at timestep {}'.format(
                    world_state.dota_time))

    policy_input = {
        'env': env_state,
        'allied_heroes': allied_heroes,
        'enemy_heroes': enemy_heroes,
        'allied_nonheroes': allied_nonheroes,
        'enemy_nonheroes': enemy_nonheroes,
        'allied_towers': allied_towers,
        'enemy_towers': enemy_towers,
    }
    if debug_enabled:
        logger.debug('policy_input:\n' + pformat(policy_input))

    heads_logits, value, self.hidden = self.policy.single(
        **policy_input, hidden=self.hidden)
    if debug_enabled:
        logger.debug('heads_logits:\n' + pformat(heads_logits))
        logger.debug('value={}'.format(value))

    # Get valid actions. This mask contains all viable actions.
    action_masks = Policy.action_masks(
        player_unit=hero_unit, unit_handles=unit_handles)
    if debug_enabled:
        logger.debug('action_masks:\n' + pformat(action_masks))

    # From the heads logits and their masks, select the actions.
    action_dict = Policy.select_actions(
        heads_logits=heads_logits, masks=action_masks)
    if debug_enabled:
        logger.debug('action_dict:\n' + pformat(action_dict))

    # Given the action selections, get the head mask.
    head_masks = Policy.head_masks(selections=action_dict)
    if debug_enabled:
        logger.debug('head_masks:\n' + pformat(head_masks))

    # Combine the head mask and the selection mask, to get all relevant
    # probabilities of the current action.
    selected_heads_mask = {
        key: head_masks[key] & action_masks[key] for key in head_masks
    }
    if debug_enabled:
        logger.debug('selected_heads_mask:\n' + pformat(selected_heads_mask))

    return policy_input, action_dict, selected_heads_mask, unit_handles