def main():
    """Train a DQN agent with a PyTorch Q-network on a Snake level.

    The level comes from the command line. The network takes the last
    ``num_last_frames`` observations as input channels and outputs one
    Q-value per action of the environment.
    """
    parsed_args = parse_command_line_args(sys.argv[1:])
    num_last_frames = 4
    env = create_snake_environment(parsed_args.level)

    # Each unpadded 3x3 convolution shrinks height and width by 2, so after
    # two of them the flattened feature size depends on the field size.
    # It used to be hard-coded as 1152 (= 32 * 6 * 6), which only matched a
    # 10x10 field and caused a shape mismatch in the first nn.Linear on any
    # other level.
    # NOTE(review): assumes env.observation_shape ends with (height, width);
    # confirm against the environment implementation.
    height, width = env.observation_shape[-2:]
    if height < 5 or width < 5:
        raise ValueError(
            f"Field of {height}x{width} is too small for two 3x3 convolutions "
            "(need at least 5x5)."
        )
    conv_channels = 32
    flattened_size = conv_channels * (height - 4) * (width - 4)

    model = nn.Sequential(
        nn.Conv2d(num_last_frames, 16, 3),
        nn.ReLU(),
        nn.Conv2d(16, conv_channels, 3),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(flattened_size, 256),
        nn.ReLU(),
        nn.Linear(256, env.num_actions),
    )

    agent = DeepQNetworkAgent(
        model=model,
        env_shape=env.observation_shape,
        num_actions=env.num_actions,
        memory_size=-1,
        num_last_frames=num_last_frames,
    )
    agent.train(
        env,
        batch_size=64,
        num_episodes=parsed_args.num_episodes,
        checkpoint_freq=parsed_args.num_episodes // 10,
        discount_factor=0.95,
    )
Example #2
0
def main():
    """Train a DQN agent on the Snake level named on the command line.

    Builds the Q-network with ``create_dqn_model`` using a four-frame
    history, then trains with a batch size of 64 and a discount factor of
    0.95, checkpointing every tenth of the requested episode count.
    """
    cli_args = parse_command_line_args(sys.argv[1:])
    episode_count = cli_args.num_episodes

    snake_env = create_snake_environment(cli_args.level)
    q_network = create_dqn_model(snake_env, num_last_frames=4)

    agent = DeepQNetworkAgent(
        model=q_network,
        memory_size=-1,
        # Read the frame-history length back from the network so agent and
        # model agree (input_shape[1] is presumably the frame axis — confirm).
        num_last_frames=q_network.input_shape[1],
    )
    agent.train(
        snake_env,
        batch_size=64,
        num_episodes=episode_count,
        checkpoint_freq=episode_count // 10,
        discount_factor=0.95,
    )
def create_agent(name, model=None):
    """
    Create a specific type of Snake AI agent.

    Args:
        name (str): key identifying the agent type: 'human', 'dqn' or 'random'.
        model: (optional) a pre-trained model. Required when ``name`` is
            'dqn'; not used by the other agent types. Defaults to None.

    Returns:
        An instance of Snake agent.

    Raises:
        ValueError: if ``name`` is 'dqn' and no model was supplied.
        KeyError: if ``name`` is not a known agent type.
    """

    # Imported inside the function, presumably so the agent package is only
    # loaded when an agent is actually requested.
    from snakeai.agent import DeepQNetworkAgent, HumanAgent, RandomActionAgent

    if name == 'human':
        return HumanAgent()
    elif name == 'dqn':
        if model is None:
            raise ValueError('A model file is required for a DQN agent.')
        return DeepQNetworkAgent(model=model,
                                 memory_size=-1,
                                 num_last_frames=4)
    elif name == 'random':
        return RandomActionAgent()

    raise KeyError(f'Unknown agent type: "{name}"')
Example #4
0
def create_agent(name, model, env):
    """
    Build the Snake AI agent identified by ``name``.

    Args:
        name (str): agent type; one of "human", "random" or "dqn".
        model: (optional) a pre-trained model; must not be None for "dqn".
        env: an instance of Snake environment; for "dqn" its
            ``observation_shape`` and ``num_actions`` configure the agent.
    Returns:
        An instance of Snake agent.
    Raises:
        ValueError: "dqn" was requested without a model.
        KeyError: ``name`` is not a recognised agent type.
    """

    from snakeai.agent import DeepQNetworkAgent, HumanAgent, RandomActionAgent

    # Agents that need no configuration are returned straight away.
    if name == "human":
        return HumanAgent()
    if name == "random":
        return RandomActionAgent()

    if name != "dqn":
        raise KeyError(f'Unknown agent type: "{name}"')

    if model is None:
        raise ValueError("A model file is required for a DQN agent.")
    return DeepQNetworkAgent(
        model=model,
        memory_size=-1,
        num_last_frames=4,
        env_shape=env.observation_shape,
        num_actions=env.num_actions,
    )
Example #5
0
def main():
    """Train a DQN agent using CLI settings, falling back to ``Config``.

    Each run writes into a new timestamped folder under ``outputs/``, which
    also receives copies of the level file and ``config.py``. The agent
    receives a list of two models, either loaded from a file or freshly
    built.
    """
    args = parse_command_line_args(sys.argv[1:])
    # A truthy command-line value wins; otherwise use the config default.
    level = args.level or Config.LEVEL
    num_episodes = args.num_episodes or Config.NUM_EPISODES

    # NOTE(review): 'epsiodes' is a typo for 'episodes'; kept unchanged since
    # existing output folders or tooling may rely on this exact name.
    timestamp = time.strftime('%Y%m%d-%H%M%S')
    run_folder = (timestamp + '_' + Config.LEARNING_METHOD + '_'
                  + str(num_episodes) + 'epsiodes_' + os.path.basename(level))
    output_path = os.path.join('outputs', run_folder)
    os.makedirs(output_path)

    # Snapshot the inputs of this run next to its results.
    copy(level, output_path)
    copy("config.py", output_path)

    env = create_snake_environment(level, output_path)

    def build_model():
        # Priority: model file from the CLI, then the configured pretrained
        # model, otherwise a freshly initialised DQN.
        if args.model:
            return load_model(args.model)
        if Config.USE_PRETRAINED_MODEL:
            return load_model(Config.PRETRAINED_MODEL)
        return create_dqn_model(env, num_last_frames=Config.NUM_LAST_FRAMES)

    # Two independently built models from the same source.
    models = [build_model() for _ in range(2)]

    agent = DeepQNetworkAgent(model=models,
                              memory_size=Config.MEMORY_SIZE,
                              num_last_frames=models[0].input_shape[1],
                              output=output_path)
    agent.train(env,
                batch_size=Config.BATCH_SIZE,
                num_episodes=num_episodes,
                checkpoint_freq=100,
                discount_factor=Config.DISCOUNT_FACTOR,
                method=Config.LEARNING_METHOD,
                multi_step=Config.MULTI_STEP_REWARD)
Example #6
0
def main():
    """Train a DQN agent, writing its outputs to a timestamped folder.

    The folder ``outputs/<YYYYmmdd-HHMMSS>`` is created before the command
    line is parsed, and its path is handed to both the environment and the
    agent.
    """
    # os.makedirs raises FileExistsError if the folder already exists
    # (e.g. two runs started within the same second).
    run_dir = os.path.join('outputs', time.strftime('%Y%m%d-%H%M%S'))
    os.makedirs(run_dir)

    cli_args = parse_command_line_args(sys.argv[1:])
    episodes = cli_args.num_episodes

    env = create_snake_environment(cli_args.level, run_dir)
    dqn_model = create_dqn_model(env, num_last_frames=4)

    agent = DeepQNetworkAgent(
        model=dqn_model,
        memory_size=-1,
        # Take the frame-history length from the model so the two agree.
        num_last_frames=dqn_model.input_shape[1],
        output=run_dir,
    )
    agent.train(
        env,
        batch_size=64,
        num_episodes=episodes,
        checkpoint_freq=episodes // 10,
        discount_factor=0.95,
    )
Example #7
0
def main():
    parsed_args = parse_command_line_args(sys.argv[1:])

    env = create_snake_environment(parsed_args.level)
    if parsed_args.attention == -1:
        model = create_dqn_model(env, num_last_frames=4)
    else:
        model = create_vin_model(env,
                                 num_last_frames=4,
                                 attention=parsed_args.attention)
    #model.load_weights('dqn-final.model')
    agent = DeepQNetworkAgent(model=model,
                              memory_size=-1,
                              num_last_frames=4,
                              attention=parsed_args.attention)
    agent.train(env,
                batch_size=64,
                num_episodes=parsed_args.num_episodes,
                checkpoint_freq=parsed_args.num_episodes // 10,
                discount_factor=0.95)
    '''agent.train(