Example #1
def run(agent: Agent,
        environment: dm_env.Environment,
        num_episodes: int,
        results_dir: str = 'res/default.pkl') -> None:
    '''
    Runs an agent on an environment.

    Args:
        agent: The agent to train and evaluate.
        environment: The environment to train on.
        num_episodes: Number of episodes to train for.
        results_dir: Path where the agent's network checkpoint is saved.
    '''

    for episode in range(num_episodes):
        # Run an episode.
        timestep = environment.reset()

        while not timestep.last():
            action = agent.select_action(timestep)
            print(action)
            new_timestep = environment.step(action)

            # Pass the (s, a, r, s') transition to the agent.
            agent.update(timestep, action, new_timestep)

            # Update the timestep.
            timestep = new_timestep

        if (episode + 1) % 100 == 0:
            print("Episode %d complete." % (episode + 1))

        # Checkpoint the agent's network after every episode.
        torch.save(agent._network, results_dir)
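
A hedged usage sketch for the function above: `make_environment` and `DQNAgent` are hypothetical placeholders (the original snippet does not define them), assumed only to yield a dm_env.Environment and an agent exposing select_action(), update() and a torch module at `_network`.

env = make_environment()             # hypothetical environment factory
agent = DQNAgent(env.action_spec())  # hypothetical agent class
run(agent, env, num_episodes=1000, results_dir='res/dqn.pkl')
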
Example #2
def run(agent: base.Agent,
        environment: dm_env.Environment,
        num_episodes: int,
        verbose: bool = False) -> None:
  """Runs an agent on an environment.

  Note that for bsuite environments, logging is handled internally.

  Args:
    agent: The agent to train and evaluate.
    environment: The environment to train on.
    num_episodes: Number of episodes to train for.
    verbose: Whether to also log to terminal.
  """

  if verbose:
    environment = terminal_logging.wrap_environment(
        environment, log_every=True)  # pytype: disable=wrong-arg-types

  for _ in range(num_episodes):
    # Run an episode.
    timestep = environment.reset()
    while not timestep.last():
      # Generate an action from the agent's policy.
      action = agent.select_action(timestep)

      # Step the environment.
      new_timestep = environment.step(action)

      # Tell the agent about what just happened.
      agent.update(timestep, action, new_timestep)

      # Book-keeping.
      timestep = new_timestep
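
A hedged sketch of driving this loop with a bsuite environment, as the docstring's note on internal logging suggests. The bsuite id, the results directory and the RandomAgent class are illustrative assumptions, not part of the original snippet; only `bsuite.load_and_record_to_csv` is taken from the bsuite API as I understand it.

import bsuite
import numpy as np

class RandomAgent:
  """Minimal agent implementing the select_action/update interface used above."""

  def __init__(self, num_actions: int, seed: int = 0):
    self._num_actions = num_actions
    self._rng = np.random.RandomState(seed)

  def select_action(self, timestep):
    return self._rng.randint(self._num_actions)

  def update(self, timestep, action, new_timestep):
    pass  # A random agent does not learn.

env = bsuite.load_and_record_to_csv('catch/0', results_dir='/tmp/bsuite')  # assumed id and path
run(RandomAgent(env.action_spec().num_values), env, num_episodes=100, verbose=True)
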
Example #3
def run_episode(agent: Agent,
                env: Environment,
                action_repeat: int = 1,
                update: bool = False) -> dict:
    """Runs a single episode and returns basic episode statistics."""
    start_time = time()
    episode_steps = 0
    episode_return = 0
    timestep = env.reset()
    agent.observe_first(timestep)
    while not timestep.last():
        action = agent.select_action(timestep.observation)
        for _ in range(action_repeat):
            timestep = env.step(action)
            agent.observe(action, next_timestep=timestep)
            episode_steps += 1
            episode_return += timestep.reward
            if update:
                agent.update()
            if timestep.last():
                break

    steps_per_second = episode_steps / (time() - start_time)
    result = {
        'episode_length': episode_steps,
        'episode_return': episode_return,
        'steps_per_second': steps_per_second,
    }
    return result
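
A hedged sketch of aggregating the dictionary returned above over a training run; the episode count and the mean-return computation are illustrative assumptions, and `agent`/`env` are assumed to satisfy the interfaces run_episode expects.

returns = []
for _ in range(100):  # assumed number of training episodes
    result = run_episode(agent, env, action_repeat=4, update=True)
    returns.append(result['episode_return'])

print('Mean episode return:', sum(returns) / len(returns))
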
Example #4
def run(
    agent: base.Agent,
    env: dm_env.Environment,
    num_episodes: int,
    eval_mode: bool = False,
) -> base.Agent:
    """Runs (or evaluates) the agent on the environment, logging metrics to wandb."""
    wandb.init(project="dqn")
    logging.info(
        "Starting {} agent {} on environment {}.\nThe scheduled number of episodes is {}"
        .format("evaluating" if eval_mode else "training", agent, env,
                num_episodes))
    for episode in range(num_episodes):
        print(
            "Starting episode number {}/{}\t\t\t".format(
                episode, num_episodes - 1),
            end="\r",
        )
        wandb.log({"Episode": episode})
        # initialise environment
        timestep = env.reset()
        while not timestep.last():
            # policy
            action = agent.select_action(timestep)
            # step environment
            new_timestep = env.step(tuple(action))
            wandb.log({"Reward": new_timestep.reward})
            # update
            if not eval_mode:
                loss = agent.update(timestep, action, new_timestep)
                if loss is not None:
                    wandb.log({"Bellman MSE": float(loss)})
                wandb.log({"Iteration": agent.iteration})
            # prepare next
            timestep = new_timestep
    return agent
Example #5
def observe(
    self, env: dm_env.Environment, timestep: dm_env.TimeStep, action: int
) -> dm_env.TimeStep:
    # Iterate over the configured number of steps.
    for _ in range(self.hparams.n_steps):
        # Get the new MDP state.
        new_timestep = env.step(action)
        # Store the transition into the replay buffer.
        self.memory.add(timestep, action, new_timestep, preprocess=self.preprocess)
        timestep = new_timestep
    return timestep
Example #6
def run_loop(
    agent: Agent,
    environment: dm_env.Environment,
    max_steps_per_episode: int = 0,
    yield_before_reset: bool = False,
) -> Iterable[Tuple[dm_env.Environment, Optional[dm_env.TimeStep], Agent,
                    Optional[Action]]]:
    """Repeatedly alternates step calls on environment and agent.

    At time `t`, `t + 1` environment timesteps and `t + 1` agent steps have been
    seen in the current episode. `t` resets to `0` for the next episode.

    Args:
      agent: Agent to be run, has methods `step(timestep)` and `reset()`.
      environment: Environment to run, has methods `step(action)` and `reset()`.
      max_steps_per_episode: If positive, when time t reaches this value within an
        episode, the episode is truncated.
      yield_before_reset: Whether to additionally yield `(environment, None,
        agent, None)` before the agent and environment are reset at the start of
        each episode.

    Yields:
      Tuple `(environment, timestep_t, agent, a_t)` where
      `a_t = agent.step(timestep_t)`.
    """
    while True:  # For each episode.
        if yield_before_reset:
            yield environment, None, agent, None

        t = 0
        agent.reset()
        timestep_t = environment.reset()  # timestep_0.

        while True:  # For each step in the current episode.
            a_t = agent.step(timestep_t)
            yield environment, timestep_t, agent, a_t

            # Update t after one environment step and agent step and relabel.
            t += 1
            a_tm1 = a_t
            timestep_t = environment.step(a_tm1)

            if max_steps_per_episode > 0 and t >= max_steps_per_episode:
                assert t == max_steps_per_episode
                timestep_t = timestep_t._replace(
                    step_type=dm_env.StepType.LAST)

            if timestep_t.last():
                unused_a_t = agent.step(
                    timestep_t)  # Extra agent step, action ignored.
                yield environment, timestep_t, agent, None
                break
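
A hedged sketch of consuming this generator: since the loop is infinite, the total number of yields is capped with `itertools.islice`. The step budget and the per-episode return tracking are illustrative assumptions, and `agent`/`environment` are assumed to satisfy the interfaces named in the docstring.

import itertools

episode_return = 0.0
loop = run_loop(agent, environment, max_steps_per_episode=1000)
for _, timestep_t, _, _ in itertools.islice(loop, 100_000):  # cap total yields (assumption)
    if timestep_t is None:
        continue  # Only yielded when yield_before_reset=True.
    if timestep_t.first():
        episode_return = 0.0
    elif timestep_t.reward is not None:
        episode_return += timestep_t.reward
    if timestep_t.last():
        print('Episode return:', episode_return)
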