Exemplo n.º 1
0
def create_agent(environment, obs_stacker, agent_type='DQN'):
    """Creates a single learning agent for Hanabi.

    Args:
      environment: The Hanabi environment; its `num_moves()` and `players`
        configure the agent's action space and player count.
      obs_stacker: Observation stacker; its `observation_size()` sets the
        agent's input size.
      agent_type: str, either 'DQN' or 'Rainbow'.

    Returns:
      A newly constructed agent of the requested type.

    Raises:
      ValueError: if `agent_type` is neither 'DQN' nor 'Rainbow'.
    """
    # Pick the agent class first; both supported types share one constructor
    # signature, so construction happens once below.
    if agent_type == 'DQN':
        agent_class = dqn_agent.DQNAgent
    elif agent_type == 'Rainbow':
        agent_class = rainbow_agent.RainbowAgent
    else:
        raise ValueError(
            'Expected valid agent_type, got {}'.format(agent_type))

    return agent_class(
        observation_size=obs_stacker.observation_size(),
        num_actions=environment.num_moves(),
        num_players=environment.players)
Exemplo n.º 2
0
def _validate_adhoc_team(team, agent_repo=('DQN', 'SimpleAgent', 'Rainbow'),
                         max_players=5):
    """Checks a {agent_type: [seat, ...]} team definition.

    A team is valid when every agent type is known, no seat is assigned twice,
    every seat lies in [0, max_players), and the seats form the contiguous
    range 0..n-1. A message is printed for the first problem found.

    :param team: dict mapping agent type name to a list of seat indices.
    :param agent_repo: agent type names that may appear in ``team``.
    :param max_players: exclusive upper bound on seat indices.
    :return: the number of seats n if the team is valid, otherwise None.
    """
    if not all(agent in agent_repo for agent in team):
        print("Agent doesnt exist!")
        return None

    seen = set()
    for positions in team.values():
        for pos in positions:
            if pos in seen or pos not in range(max_players):
                print("TEAM NOT CORRECTLY DEFINED! CHECK positions")
                return None
            seen.add(pos)

    # Seats must be exactly 0..n-1 so the returned agent list has no gaps.
    if sorted(seen) != list(range(len(seen))):
        print('Index positions are not sequential')
        return None
    return len(seen)


def _build_adhoc_agent(agent_type, environment, obs_stacker):
    """Constructs one fresh agent of ``agent_type`` for an ad hoc team.

    :raises ValueError: if ``agent_type`` is not 'DQN', 'Rainbow' or
        'SimpleAgent'.
    """
    # RL agents get the game's real player count from the environment, as in
    # create_agent; the team number is a composition id, not a player count.
    if agent_type == 'DQN':
        return dqn_agent.DQNAgent(
            observation_size=obs_stacker.observation_size(),
            num_actions=environment.num_moves(),
            num_players=environment.players)
    if agent_type == 'Rainbow':
        return rainbow_agent.RainbowAgent(
            observation_size=obs_stacker.observation_size(),
            num_actions=environment.num_moves(),
            num_players=environment.players)
    if agent_type == 'SimpleAgent':
        return rule_based_agent.RuleBasedAgent(players=environment.players)
    raise ValueError('Expected valid agent_type, got {}'.format(agent_type))


def create_adhoc_team(environment, obs_stacker, team_no, rl_shared_model=True):
    """Creates a predefined ad hoc Hanabi team of RL and rule-based agents.

    Team compositions (seat indices per agent type):
      1: Rainbow at seat 0,       SimpleAgent at seats 1, 2, 3.
      2: Rainbow at seats 0, 1,   SimpleAgent at seats 2, 3.
      3: Rainbow at seats 0, 1, 2, SimpleAgent at seat 3.

    :param environment: The Hanabi environment; its ``num_moves()`` and
        ``players`` configure the constructed agents.
    :param obs_stacker: Observation stacker; its ``observation_size()``
        configures the RL agents.
    :param team_no: int, which predefined composition to build (1, 2 or 3).
    :param rl_shared_model: bool; when True (and
        ``dqn_agent.DQNAgent.is_rl_agent()`` is true) all seats of the same
        RL agent type share a single agent instance. Rule-based agents are
        never shared.
    :return: list of agents indexed by seat position, or None (after printing
        a message) if ``team_no`` is unknown or the team definition is invalid.
    :raises ValueError: if a team names an agent type with no constructor.
    """
    if team_no == 1:
        team = {'Rainbow': [0], 'SimpleAgent': [1, 2, 3]}
    elif team_no == 2:
        team = {'Rainbow': [0, 1], 'SimpleAgent': [2, 3]}
    elif team_no == 3:
        team = {'Rainbow': [0, 1, 2], 'SimpleAgent': [3]}
    else:
        print("No valid team number defined!!")
        return None

    num_seats = _validate_adhoc_team(team)
    if num_seats is None:
        return None

    rl_agent_types = ('DQN', 'Rainbow')
    agent_list = [None] * num_seats
    for agent_type, positions in team.items():
        # Model sharing only makes sense for RL agents; each rule-based seat
        # gets its own instance.
        share = (rl_shared_model
                 and agent_type in rl_agent_types
                 and dqn_agent.DQNAgent.is_rl_agent())
        agent = None
        for pos in positions:
            if agent is None or not share:
                agent = _build_adhoc_agent(agent_type, environment, obs_stacker)
            agent_list[pos] = agent

    return agent_list