Example 1
def run(agent: Agent,
        environment: dm_env.Environment,
        num_episodes: int,
        results_dir: str = 'res/default.pkl') -> None:
    '''
    Runs an agent on an environment.

    Args:
        agent: The agent to train and evaluate.
        environment: The environment to train on.
        num_episodes: Number of episodes to train for.
        results_dir: Path of the file to save the agent's network to.
    '''

    for episode in range(num_episodes):
        # Run an episode.
        timestep = environment.reset()

        while not timestep.last():
            action = agent.select_action(timestep)
            print(action)
            new_timestep = environment.step(action)

            # Pass the (s, a, r, s') info to the agent.
            agent.update(timestep, action, new_timestep)

            # Update the timestep.
            timestep = new_timestep

        if (episode + 1) % 100 == 0:
            print("Episode %d success." % (episode + 1))

        # Save the agent's network after each episode.
        torch.save(agent._network, results_dir)
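
For completeness, the network saved by this loop can be restored later with torch.load. The sketch below is illustrative only; it assumes the default path 'res/default.pkl' and that the full nn.Module object (not just a state_dict) was saved, as in the example above.

# Illustrative only: restoring the network that the loop above saved.
import torch

network = torch.load('res/default.pkl')  # On PyTorch >= 2.6, full-module checkpoints may need weights_only=False.
network.eval()  # Switch to inference mode before evaluating the policy.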
Example 2
def run(agent: base.Agent,
        environment: dm_env.Environment,
        num_episodes: int,
        verbose: bool = False) -> None:
  """Runs an agent on an environment.

  Note that for bsuite environments, logging is handled internally.

  Args:
    agent: The agent to train and evaluate.
    environment: The environment to train on.
    num_episodes: Number of episodes to train for.
    verbose: Whether to also log to terminal.
  """

  if verbose:
    environment = terminal_logging.wrap_environment(
        environment, log_every=True)  # pytype: disable=wrong-arg-types

  for _ in range(num_episodes):
    # Run an episode.
    timestep = environment.reset()
    while not timestep.last():
      # Generate an action from the agent's policy.
      action = agent.select_action(timestep)

      # Step the environment.
      new_timestep = environment.step(action)

      # Tell the agent about what just happened.
      agent.update(timestep, action, new_timestep)

      # Book-keeping.
      timestep = new_timestep
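
To make the loop above concrete, here is a minimal, purely illustrative driver. It assumes dm_env and numpy are installed and that `run` is in scope; the toy `_CountdownEnv` and `_RandomAgent` classes are made up for this sketch and merely duck-type the methods the loop relies on.

# Minimal sketch of driving the run() loop above with a toy agent/environment.
import dm_env
import numpy as np
from dm_env import specs


class _CountdownEnv(dm_env.Environment):
    """Toy environment that terminates after a fixed number of steps."""

    def __init__(self, horizon: int = 10):
        self._horizon = horizon
        self._t = 0

    def reset(self) -> dm_env.TimeStep:
        self._t = 0
        return dm_env.restart(np.zeros((1,), dtype=np.float32))

    def step(self, action) -> dm_env.TimeStep:
        self._t += 1
        obs = np.full((1,), self._t, dtype=np.float32)
        if self._t >= self._horizon:
            return dm_env.termination(reward=1.0, observation=obs)
        return dm_env.transition(reward=0.0, observation=obs)

    def observation_spec(self):
        return specs.Array(shape=(1,), dtype=np.float32, name='obs')

    def action_spec(self):
        return specs.DiscreteArray(num_values=2, name='action')


class _RandomAgent:
    """Toy agent exposing only the two methods the loop calls."""

    def select_action(self, timestep: dm_env.TimeStep) -> int:
        return np.random.randint(2)

    def update(self, timestep, action, new_timestep) -> None:
        pass  # A real agent would learn from the (s, a, r, s') transition here.


run(_RandomAgent(), _CountdownEnv(), num_episodes=5)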
Example 3
def run_episode(agent: Agent,
                env: Environment,
                action_repeat: int = 1,
                update: bool = False) -> dict:
    """Runs a single episode and returns its length, return and steps/second."""
    start_time = time()
    episode_steps = 0
    episode_return = 0
    timestep = env.reset()
    agent.observe_first(timestep)
    while not timestep.last():
        action = agent.select_action(timestep.observation)
        for _ in range(action_repeat):
            timestep = env.step(action)
            agent.observe(action, next_timestep=timestep)
            episode_steps += 1
            episode_return += timestep.reward
            if update:
                agent.update()
            if timestep.last():
                break

    steps_per_second = episode_steps / (time() - start_time)
    result = {
        'episode_length': episode_steps,
        'episode_return': episode_return,
        'steps_per_second': steps_per_second,
    }
    return result
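
One possible way to use run_episode, shown purely for illustration (construction of `agent` and `env` is assumed to happen elsewhere), is to run several episodes and aggregate the returned statistics:

# Illustrative: aggregate the statistics returned by run_episode.
num_episodes = 10
returns = []
for _ in range(num_episodes):
    result = run_episode(agent, env, action_repeat=2, update=True)
    returns.append(result['episode_return'])

print('mean return over {} episodes: {:.2f}'.format(
    num_episodes, sum(returns) / len(returns)))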
Example 4
def run(
    agent: base.Agent,
    env: dm_env.Environment,
    num_episodes: int,
    eval_mode: bool = False,
) -> base.Agent:
    wandb.init(project="dqn")
    logging.info(
        "Starting {} agent {} on environment {}.\nThe scheduled number of episodes is {}"
        .format("evaluating" if eval_mode else "training", agent, env,
                num_episodes))
    for episode in range(num_episodes):
        print(
            "Starting episode number {}/{}\t\t\t".format(
                episode, num_episodes - 1),
            end="\r",
        )
        wandb.log({"Episode": episode})
        # initialise environment
        timestep = env.reset()
        while not timestep.last():
            # policy
            action = agent.select_action(timestep)
            # step environment
            new_timestep = env.step(tuple(action))
            wandb.log({"Reward": new_timestep.reward})
            # update
            if not eval_mode:
                loss = agent.update(timestep, action, new_timestep)
                if loss is not None:
                    wandb.log({"Bellman MSE": float(loss)})
                wandb.log({"Iteration": agent.iteration})
            # prepare next
            timestep = new_timestep
    return agent
Example 5
def run_loop(
    agent: Agent,
    environment: dm_env.Environment,
    max_steps_per_episode: int = 0,
    yield_before_reset: bool = False,
) -> Iterable[Tuple[dm_env.Environment, Optional[dm_env.TimeStep], Agent,
                    Optional[Action]]]:
    """Repeatedly alternates step calls on environment and agent.

    At time `t`, `t + 1` environment timesteps and `t + 1` agent steps have been
    seen in the current episode. `t` resets to `0` for the next episode.

    Args:
      agent: Agent to be run, has methods `step(timestep)` and `reset()`.
      environment: Environment to run, has methods `step(action)` and `reset()`.
      max_steps_per_episode: If positive, when time t reaches this value within an
        episode, the episode is truncated.
      yield_before_reset: Whether to additionally yield `(environment, None,
        agent, None)` before the agent and environment are reset at the start of
        each episode.

    Yields:
      Tuple `(environment, timestep_t, agent, a_t)` where
      `a_t = agent.step(timestep_t)`.
    """
    while True:  # For each episode.
        if yield_before_reset:
            yield environment, None, agent, None

        t = 0
        agent.reset()
        timestep_t = environment.reset()  # timestep_0.

        while True:  # For each step in the current episode.
            a_t = agent.step(timestep_t)
            yield environment, timestep_t, agent, a_t

            # Update t after one environment step and agent step and relabel.
            t += 1
            a_tm1 = a_t
            timestep_t = environment.step(a_tm1)

            if max_steps_per_episode > 0 and t >= max_steps_per_episode:
                assert t == max_steps_per_episode
                timestep_t = timestep_t._replace(
                    step_type=dm_env.StepType.LAST)

            if timestep_t.last():
                unused_a_t = agent.step(
                    timestep_t)  # Extra agent step, action ignored.
                yield environment, timestep_t, agent, None
                break
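
Since run_loop is an infinite generator, the caller decides when to stop. The consumer below is a hedged sketch (construction of `agent` and `environment` is assumed) that accumulates per-episode returns and stops after ten episodes, relying on the fact that the final yield of each episode has `a_t` set to None.

# Illustrative consumer of run_loop.
num_episodes = 0
episode_return = 0.0
for _, timestep_t, _, a_t in run_loop(agent, environment,
                                      max_steps_per_episode=1000):
    if timestep_t is None:
        continue  # Only happens when yield_before_reset=True.
    if timestep_t.reward is not None:  # Reward is None on the first timestep.
        episode_return += timestep_t.reward
    if a_t is None:  # Final yield of an episode: timestep_t.last() is True.
        num_episodes += 1
        print('episode {} return: {}'.format(num_episodes, episode_return))
        episode_return = 0.0
        if num_episodes >= 10:
            break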
Example 6
    def run(
        self,
        env: dm_env.Environment,
        num_episodes: int,
        eval: bool = False,
    ) -> Loss:
        logging.info(
            "Starting {} agent {} on environment {}.\nThe scheduled number of episodes is {}"
            .format("evaluating" if eval else "training", self, env,
                    num_episodes))
        logging.info(
            "The hyperparameters for the current experiment are {}".format(
                self.hparams._asdict()))
        # Make sure `loss` is defined even if no update ever runs.
        loss = None
        for episode in range(num_episodes):
            print(
                "Episode {}/{}\t\t\t".format(episode, num_episodes - 1),
                end="\r",
            )
            # Initialise the environment.
            episode_reward = 0.0
            timestep = env.reset()
            while not timestep.last():
                # Apply the policy.
                action = self.policy(timestep)
                # Observe the new state.
                new_timestep = self.observe(env, timestep, action)
                episode_reward += new_timestep.reward
                print(
                    "Episode reward {}\t\t".format(episode_reward),
                    end="\r",
                )
                # Update the policy (only when not evaluating).
                loss = None
                if not eval:
                    loss = self.update(timestep, action, new_timestep)
                # Log the update.
                if self.logging:
                    self.log(timestep, action, new_timestep, loss)
                # Prepare the next iteration.
                timestep = new_timestep
        return loss
Example 7
def actor(server: Connection, client: Connection, env: dm_env.Environment):
    def _step(env, a: int):
        timestep = env.step(a)
        if timestep.last():
            timestep = env.reset()
        return timestep

    def _step_async(env, a: int, buffer: mp.Queue):
        timestep = env.step(a)
        buffer.put(timestep)
        print(buffer.qsize())
        if timestep.last():
            timestep = env.reset()
        return

    # Close this process's copy of the server end of the pipe.
    # See: https://stackoverflow.com/q/8594909/6655465
    server.close()
    # Dispatch incoming commands until asked to close.
    try:
        while True:
            cmd, data = client.recv()
            if cmd == "step":
                client.send(_step(env, data))
            elif cmd == "step_async":
                client.send(_step_async(env, *data))
            elif cmd == "reset":
                client.send(env.reset())
            elif cmd == "render":
                client.send(env.render(data))
            elif cmd == "close":
                client.send(env.close())
                break
            else:
                raise NotImplementedError("Command {} is not implemented".format(cmd))
    except KeyboardInterrupt:
        logging.info("SubprocVecEnv actor: got KeyboardInterrupt")
    finally:
        env.close()
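
For context, the parent-process side of this actor might look like the sketch below. It is illustrative only: `make_env` is a hypothetical factory for the dm_env environment, and the sketch relies on the standard multiprocessing Pipe/Process API (fork-style pickling of the environment is assumed).

# Illustrative parent-side driver for the `actor` above.
import multiprocessing as mp


def start_worker(make_env):
    server, client = mp.Pipe()
    process = mp.Process(target=actor, args=(server, client, make_env()),
                         daemon=True)
    process.start()
    client.close()  # The parent keeps only the server end of the pipe.
    return server, process


server, process = start_worker(make_env)
server.send(("reset", None))
first_timestep = server.recv()
server.send(("step", 0))        # Take action 0 for one environment step.
next_timestep = server.recv()
server.send(("close", None))
server.recv()                   # Acknowledgement from env.close().
process.join()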
Example 8
    def __init__(
        self,
        agent: agent_lib.Agent,
        env: dm_env.Environment,
        unroll_length: int,
        learner: learner_lib.Learner,
        rng_seed: int = 42,
        logger=None,
    ):
        self._agent = agent
        self._env = env
        self._unroll_length = unroll_length
        self._learner = learner
        self._timestep = env.reset()
        self._agent_state = agent.initial_state(None)
        self._traj = []
        self._rng_key = jax.random.PRNGKey(rng_seed)

        if logger is None:
            logger = util.NullLogger()
        self._logger = logger

        self._episode_return = 0.