def create_agent(environment, obs_stacker, agent_type='DQN'):
  """Builds a single Hanabi agent of the requested kind.

  Args:
    environment: The environment; its `num_moves()` result and `players`
      attribute are forwarded to the agent constructor.
    obs_stacker: Observation stacker object; its `observation_size()` result
      is forwarded to the agent constructor.
    agent_type: str, either 'DQN' or 'Rainbow'.

  Returns:
    A `dqn_agent.DQNAgent` or `rainbow_agent.RainbowAgent` instance.

  Raises:
    ValueError: if `agent_type` is neither 'DQN' nor 'Rainbow'.
  """
  # Reject unknown types up front so nothing is read from the environment
  # or the observation stacker for a request that cannot be served.
  if agent_type not in ('DQN', 'Rainbow'):
    raise ValueError(
        'Expected valid agent_type, got {}'.format(agent_type))

  # Resolve the class lazily: only the module for the requested type is
  # touched.
  if agent_type == 'DQN':
    agent_cls = dqn_agent.DQNAgent
  else:
    agent_cls = rainbow_agent.RainbowAgent

  # Both agent classes are constructed with the same keyword arguments.
  return agent_cls(
      observation_size=obs_stacker.observation_size(),
      num_actions=environment.num_moves(),
      num_players=environment.players)
def _build_team_agent(agent_type, environment, obs_stacker, team_no):
  """Constructs one fresh agent of `agent_type` for an ad-hoc team.

  Args:
    environment: The environment; supplies `num_moves()` and `players`.
    obs_stacker: Observation stacker object; supplies `observation_size()`.
    agent_type: str, one of 'DQN', 'Rainbow' or 'SimpleAgent'.
    team_no: int, the team number; passed as `num_players` to learning agents.

  Returns:
    A newly constructed agent instance.

  Raises:
    ValueError: if `agent_type` is not one of the supported names.
  """
  if agent_type == 'DQN':
    # NOTE(review): `num_players` receives the team number, not
    # `environment.players` as in `create_agent`. For the predefined teams the
    # team number equals the count of Rainbow seats -- confirm this is intended.
    return dqn_agent.DQNAgent(
        observation_size=obs_stacker.observation_size(),
        num_actions=environment.num_moves(),
        num_players=team_no)
  if agent_type == 'Rainbow':
    return rainbow_agent.RainbowAgent(
        observation_size=obs_stacker.observation_size(),
        num_actions=environment.num_moves(),
        num_players=team_no)
  if agent_type == 'SimpleAgent':
    return rule_based_agent.RuleBasedAgent(players=environment.players)
  raise ValueError('Expected valid agent_type, got {}'.format(agent_type))


def create_adhoc_team(environment, obs_stacker, team_no, rl_shared_model=True):
  """Creates a predefined mixed team of Rainbow and rule-based agents.

  Predefined teams (seat positions per agent type):
    1: Rainbow at [0],       SimpleAgent at [1, 2, 3]
    2: Rainbow at [0, 1],    SimpleAgent at [2, 3]
    3: Rainbow at [0, 1, 2], SimpleAgent at [3]

  Args:
    environment: The environment; supplies `num_moves()` and `players`.
    obs_stacker: Observation stacker object; supplies `observation_size()`.
    team_no: int, which predefined team (1, 2 or 3) to build.
    rl_shared_model: bool, when truthy (and `dqn_agent.DQNAgent.is_rl_agent()`
      is truthy) every seat of the same agent type receives the same agent
      instance instead of a fresh one per seat.

  Returns:
    A list of agents indexed by seat position, or None (after printing a
    diagnostic) if the team number is unknown or the team definition is
    inconsistent.

  Raises:
    ValueError: if a team names an agent type that cannot be constructed.
  """
  known_agent_types = ('DQN', 'SimpleAgent', 'Rainbow')

  if team_no == 1:
    team = {'Rainbow': [0], 'SimpleAgent': [1, 2, 3]}
  elif team_no == 2:
    team = {'Rainbow': [0, 1], 'SimpleAgent': [2, 3]}
  elif team_no == 3:
    team = {'Rainbow': [0, 1, 2], 'SimpleAgent': [3]}
  else:
    print("No valid team number defined!!")
    return None

  # Sanity-check the team definition so that editing the tables above cannot
  # silently produce a broken team.
  if not all(agent_type in known_agent_types for agent_type in team):
    print("Agent doesnt exist!")
    return None

  # Every seat must be listed exactly once and lie in 0..4.
  seen_positions = set()
  for positions in team.values():
    for pos in positions:
      if pos in seen_positions or pos not in range(5):
        print("TEAM NOT CORRECTLY DEFINED! CHECK positions")
        return None
      seen_positions.add(pos)

  # Seats must form the gap-free sequence 0, 1, ..., n-1 because they are used
  # directly as indices into the returned list.
  ordered_positions = sorted(seen_positions)
  if ordered_positions != list(range(len(ordered_positions))):
    print('Index positions are not sequential')
    return None

  # One slot per seat; every slot is filled below since seats cover 0..n-1.
  agent_list = [None] * len(ordered_positions)
  for agent_type, positions in team.items():
    agent = None  # Most recently built agent of this type, if any.
    for pos in positions:
      # NOTE(review): the sharing gate queries DQNAgent.is_rl_agent() for every
      # agent type, so 'SimpleAgent' seats are shared under the same condition
      # as the learning agents -- confirm this is intended.
      reuse_existing = (dqn_agent.DQNAgent.is_rl_agent() and rl_shared_model
                        and agent is not None)
      if not reuse_existing:
        agent = _build_team_agent(agent_type, environment, obs_stacker,
                                  team_no)
      agent_list[pos] = agent
  return agent_list