def create_agent(name, model=None):
    """
    Create a specific type of Snake AI agent.

    Args:
        name (str): key identifying the agent type: 'human', 'dqn' or 'random'.
        model: (optional) a pre-trained model; required when ``name == 'dqn'``.

    Returns:
        An instance of Snake agent.

    Raises:
        KeyError: if ``name`` is not a known agent type.
        ValueError: if ``name == 'dqn'`` and no model is given.
    """
    # Reject invalid arguments before importing the agent package, so bad input
    # is reported without paying for (or depending on) the snakeai import.
    if name not in ('human', 'dqn', 'random'):
        raise KeyError(f'Unknown agent type: "{name}"')
    if name == 'dqn' and model is None:
        raise ValueError('A model file is required for a DQN agent.')

    from snakeai.agent import DeepQNetworkAgent, HumanAgent, RandomActionAgent

    if name == 'human':
        return HumanAgent()
    if name == 'dqn':
        return DeepQNetworkAgent(model=model,
                                 memory_size=-1,
                                 num_last_frames=4)
    return RandomActionAgent()
def create_agent(name, model=None, model2=None):
    """
    Create a specific type of Snake AI agent.

    Args:
        name (str): key identifying the agent type: 'human', 'minimaxdqn',
            'minimaxdqnsingle' or 'random'.
        model: (optional) a pre-trained model; required by the minimax agents.
            Passed as ``model_1`` to 'minimaxdqn' and as ``model`` to
            'minimaxdqnsingle'.
        model2: (optional) a second pre-trained model, passed as ``model_2``
            to the 'minimaxdqn' agent.

    Returns:
        An instance of Snake agent.

    Raises:
        KeyError: if ``name`` is not a known agent type.
        ValueError: if a minimax agent is requested without ``model``.
    """
    # Reject invalid arguments before importing the agent package, so bad input
    # is reported without paying for (or depending on) the snakeai import.
    if name not in ('human', 'minimaxdqn', 'minimaxdqnsingle', 'random'):
        raise KeyError(f'Unknown agent type: "{name}"')
    if name in ('minimaxdqn', 'minimaxdqnsingle') and model is None:
        raise ValueError('A model file is required for a minimax DQN agent.')

    from snakeai.agent import HumanAgent, RandomActionAgent

    if name == 'human':
        return HumanAgent()
    # NOTE(review): MinimaxDeepQNetworkAgent and MinimaxSingleDeepQNetworkAgent
    # are not imported in this function; they must come from module scope
    # (not visible here) -- TODO confirm, otherwise these branches raise NameError.
    if name == 'minimaxdqn':
        return MinimaxDeepQNetworkAgent(model_1=model, model_2=model2,
                                        memory_size=-1, num_last_frames=4)
    if name == 'minimaxdqnsingle':
        return MinimaxSingleDeepQNetworkAgent(model=model, memory_size=-1,
                                              num_last_frames=4)
    return RandomActionAgent()
# Example #3
def create_agent(name, model=None, env=None):
    """
    Create a specific type of Snake AI agent.

    Args:
        name (str): key identifying the agent type: "human", "dqn" or "random".
        model: (optional) a pre-trained model; required when ``name == "dqn"``.
        env: (optional) an instance of Snake environment; required when
            ``name == "dqn"``, whose ``observation_shape`` and ``num_actions``
            configure the DQN agent.

    Returns:
        An instance of Snake agent.

    Raises:
        KeyError: if ``name`` is not a known agent type.
        ValueError: if ``name == "dqn"`` and the model or environment is missing.
    """
    # Reject invalid arguments before importing the agent package, so bad input
    # is reported without paying for (or depending on) the snakeai import.
    if name not in ("human", "dqn", "random"):
        raise KeyError(f'Unknown agent type: "{name}"')
    if name == "dqn":
        if model is None:
            raise ValueError("A model file is required for a DQN agent.")
        if env is None:
            # Without this check the missing env would surface later as an
            # AttributeError on env.observation_shape.
            raise ValueError("An environment is required for a DQN agent.")

    from snakeai.agent import DeepQNetworkAgent, HumanAgent, RandomActionAgent

    if name == "human":
        return HumanAgent()
    if name == "dqn":
        return DeepQNetworkAgent(
            model=model,
            memory_size=-1,
            num_last_frames=4,
            env_shape=env.observation_shape,
            num_actions=env.num_actions,
        )
    return RandomActionAgent()