Exemple #1
0
  def __init__(
      self,
      agent: agent_lib.Agent,
      rng_key,
      opt: optix.InitUpdate,
      batch_size: int,
      discount_factor: float,
      frames_per_iter: int,
      max_abs_reward: float = 0,
      logger=None,
  ):
    """Stores the configuration and prepares queues and initial parameters.

    Args:
      agent: Object whose `initial_params(rng_key)` yields the starting
        parameters; also kept on the instance.
      rng_key: Key forwarded unchanged to `agent.initial_params`.
      opt: Optimizer object; only stored here.
      batch_size: Stored, and also used as the capacity of the host queue.
      discount_factor: Stored as-is.
      frames_per_iter: Stored as-is.
      max_abs_reward: Stored as-is (defaults to 0).
      logger: Optional logger; when `None` a `util.NullLogger()` is used.
    """
    multi_device = jax.device_count() > 1
    if multi_device:
      warnings.warn(
          'Note: the impala example will only take advantage of a '
          'single accelerator.')

    # Plain configuration, kept verbatim.
    self._agent, self._opt = agent, opt
    self._batch_size, self._frames_per_iter = batch_size, frames_per_iter
    self._discount_factor = discount_factor
    self._max_abs_reward = max_abs_reward

    # Data pipeline state: a completion flag plus two bounded queues. The host
    # queue holds up to `batch_size` items, the device queue a single item.
    self._done = False
    self._host_q = queue.Queue(maxsize=batch_size)
    self._device_q = queue.Queue(maxsize=1)

    # Parameters served to actors: a (0, params) pair, with the params fetched
    # through `jax.device_get`.
    # NOTE(review): the leading 0 looks like an initial counter/version --
    # confirm against the code that reads `_params_for_actor`.
    initial_params = agent.initial_params(rng_key)
    self._params_for_actor = (0, jax.device_get(initial_params))

    # Logging falls back to a null logger when none is supplied.
    self._logger = util.NullLogger() if logger is None else logger
Exemple #2
0
    def __init__(
        self,
        agent: agent_lib.Agent,
        env: dm_env.Environment,
        unroll_length: int,
        learner: learner_lib.Learner,
        rng_seed: int = 42,
        logger=None,
    ):
        """Stores the collaborators and initialises the per-instance state.

        Args:
          agent: Object whose `initial_state(None)` yields the starting agent
            state; also kept on the instance.
          env: Environment; it is reset once here and the returned timestep
            is kept as the current timestep.
          unroll_length: Stored as-is.
          learner: Stored as-is.
          rng_seed: Seed passed to `jax.random.PRNGKey` (defaults to 42).
          logger: Optional logger; when `None` a `util.NullLogger()` is used.
        """
        # Collaborators and static configuration.
        self._agent, self._env = agent, env
        self._learner = learner
        self._unroll_length = unroll_length

        # Mutable state: current timestep, agent state, an empty trajectory
        # buffer and the random key derived from the seed.
        self._timestep = env.reset()
        self._agent_state = agent.initial_state(None)
        self._traj = []
        self._rng_key = jax.random.PRNGKey(rng_seed)

        # Logging falls back to a null logger when none is supplied.
        self._logger = util.NullLogger() if logger is None else logger

        # Running return accumulator, starting at zero.
        self._episode_return = 0.