Example #1

import numpy as np

def train(agent, args, env, saver):
    """
    Train the agent.
    """

    logger = Logger(agent, args, saver.save_dir)

    # Pre-fill the Replay Buffer
    agent.initialize_memory(args.pretrain, env)

    # Begin training loop
    for episode in range(1, args.num_episodes + 1):
        # Begin each episode with a clean environment
        env.reset()
        # Get initial state
        states = env.states
        # Gather experience until done or max_steps is reached
        for t in range(args.max_steps):
            actions = agent.act(states)
            next_states, rewards, dones = env.step(actions)
            agent.step(states, actions, rewards, next_states)
            states = next_states

            logger.log(rewards, agent)
            if np.any(dones):
                break

        saver.save_checkpoint(agent, args.save_every)
        agent.new_episode()
        logger.step(episode, agent)

    env.close()
    saver.save_final(agent)
    logger.graph()
    return
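The pre-fill step behind agent.initialize_memory(args.pretrain, env) is not shown in this example; a minimal sketch of what it could look like follows, assuming a continuous action space and that the agent exposes a replay buffer with a store() method and the environment reports agent_count and action_size (these names are illustrative assumptions, not the source API).

import numpy as np

def initialize_memory(agent, pretrain_length, env):
    """Fill the replay buffer with random-action transitions before training."""
    env.reset()
    states = env.states
    for _ in range(pretrain_length):
        # Sample a uniform random action for each parallel agent in the environment
        actions = np.random.uniform(-1, 1, (env.agent_count, env.action_size))
        next_states, rewards, dones = env.step(actions)
        agent.memory.store(states, actions, rewards, next_states)
        states = next_states
        # Restart the environment whenever any episode ends mid-fill
        if np.any(dones):
            env.reset()
            states = env.states
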
Example #2

import numpy as np

def eval(agent, args, env):
    """
    Evaluate the performance of an agent using a saved weights file.
    """

    logger = Logger(agent, args)

    # Begin evaluation loop
    for episode in range(1, args.num_episodes + 1):
        # Begin each episode with a clean environment
        env.reset()
        # Get initial state
        states = env.states
        # Gather experience until done or max_steps is reached
        for t in range(args.max_steps):
            actions = agent.act(states, eval=True)
            next_states, rewards, dones = env.step(actions)
            states = next_states

            logger.log(rewards, agent)
            if np.any(dones):
                break

        agent.new_episode()
        logger.step(episode)

    env.close()
    return
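The docstring refers to a saved weights file, but the loading itself happens outside this function; a minimal sketch of restoring a PyTorch checkpoint before calling eval() might look like the following (the q_network attribute and checkpoint layout are assumptions for illustration, not the source's Saver logic).

import torch

def load_weights(agent, filepath):
    # Load a state_dict saved with torch.save() and switch the network to inference mode
    state_dict = torch.load(filepath, map_location="cpu")
    agent.q_network.load_state_dict(state_dict)
    agent.q_network.eval()
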
Example #3
def notebook_eval_agent(args, env, filename, num_eps=2):
    """
    Evaluate a saved agent from inside a notebook.
    """

    # Build a fresh agent and point the Saver at the saved weights file
    eval_agent = DQN_Agent(env.state_size, env.action_size, args)
    eval_saver = Saver(eval_agent.framework, eval_agent, args.save_dir,
                       filename)
    args.eval = True
    logger = Logger(eval_agent, args)
    # Begin evaluation loop
    for episode in range(num_eps):
        # Begin each episode with a clean environment
        env.reset()
        # Get initial state
        state = env.state
        # Gather experience until done or max_steps is reached
        for t in range(args.max_steps):
            action = eval_agent.act(state)
            next_state, reward, done = env.step(action)
            state = next_state

            logger.log(reward, eval_agent)
            if done:
                break

        eval_agent.new_episode()
        logger.step(episode)
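For reference, a call from a notebook cell might look like the line below, assuming args and env were built earlier with the project's argument parser and environment wrapper; the checkpoint filename is a placeholder.

notebook_eval_agent(args, env, "checkpoint.pth", num_eps=2)
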
Example #4
def train(agent, args, env, saver):
    """
    Train the agent.
    """

    logger = Logger(agent, args, saver.save_dir, log_every=50)

    # Pre-fill the Replay Buffer
    agent.initialize_memory(args.pretrain, env)

    # Begin training loop
    for episode in range(1, args.num_episodes + 1):
        # Begin each episode with a clean environment
        done = False
        env.reset()
        # Get initial state
        state = env.state
        # Gather experience until the environment signals done
        while not done:
            action = agent.act(state)
            next_state, reward, done = env.step(action)
            if done:
                next_state = None
            agent.step(state, action, reward, next_state)
            state = next_state

            logger.log(reward, agent)

        saver.save_checkpoint(agent, args.save_every)
        agent.new_episode()
        logger.step(episode, agent.epsilon)

    env.close()
    saver.save_final(agent)
    logger.graph()
    return
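The periodic checkpointing implied by saver.save_checkpoint(agent, args.save_every) could be implemented roughly as sketched below; the episode counter, file naming, and q_network attribute are assumptions for illustration, not the source's actual Saver class.

import os
import torch

class Saver:
    def __init__(self, save_dir="checkpoints"):
        self.save_dir = save_dir
        self.episode = 0
        os.makedirs(self.save_dir, exist_ok=True)

    def save_checkpoint(self, agent, save_every):
        # Count episodes and only write a checkpoint every `save_every` episodes
        self.episode += 1
        if self.episode % save_every != 0:
            return
        path = os.path.join(self.save_dir, f"episode_{self.episode}.pth")
        torch.save(agent.q_network.state_dict(), path)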