def __init__(
    self,
    agent: agent_lib.Agent,
    rng_key,
    opt: optix.InitUpdate,
    batch_size: int,
    discount_factor: float,
    frames_per_iter: int,
    max_abs_reward: float = 0,
    logger=None,
):
    """Store the training configuration and build the data pipeline state.

    Args:
      agent: Agent whose `initial_params(rng_key)` produces the starting
        parameters.
      rng_key: Key forwarded unchanged to `agent.initial_params`.
      opt: Optimizer object; only stored here.
      batch_size: Stored, and also used as the capacity of the host queue.
      discount_factor: Stored for later use.
      frames_per_iter: Stored for later use.
      max_abs_reward: Stored for later use. Presumably a reward-clipping
        bound where 0 disables clipping -- TODO confirm against the loss
        code, which is not visible from this constructor.
      logger: Optional logger; a `util.NullLogger()` is created when None.
    """
    # Warn up front: emitted once per construction when more than one
    # device is visible to JAX.
    if jax.device_count() > 1:
        warnings.warn(
            'Note: the impala example will only take advantage of a '
            'single accelerator.'
        )

    # Collaborators.
    self._agent, self._opt = agent, opt

    # Plain hyperparameters.
    self._batch_size = batch_size
    self._discount_factor = discount_factor
    self._frames_per_iter = frames_per_iter
    self._max_abs_reward = max_abs_reward

    # Data pipeline objects: a "done" flag plus two bounded queues. The host
    # queue holds up to `batch_size` items, the device queue a single item.
    self._done = False
    self._host_q = queue.Queue(maxsize=batch_size)
    self._device_q = queue.Queue(maxsize=1)

    # Prepare the parameters to be served to actors: a (0, params) pair
    # holding a device_get copy of the freshly initialised parameters.
    # NOTE(review): the leading 0 looks like a frame/version counter --
    # confirm against the code that reads `_params_for_actor`.
    initial_params = agent.initial_params(rng_key)
    self._params_for_actor = (0, jax.device_get(initial_params))

    # Set up logging, falling back to a no-op logger when none is supplied.
    self._logger = util.NullLogger() if logger is None else logger
def __init__(
    self,
    agent: agent_lib.Agent,
    env: dm_env.Environment,
    unroll_length: int,
    learner: learner_lib.Learner,
    rng_seed: int = 42,
    logger=None,
):
    """Bind the agent, environment and learner, and reset per-episode state.

    Args:
      agent: Agent whose `initial_state(None)` provides the starting state.
      env: Environment; `env.reset()` is called once here and its result is
        kept as the current timestep.
      unroll_length: Stored for later use.
      learner: Stored for later use.
      rng_seed: Integer seed used to build the JAX PRNG key.
      logger: Optional logger; a `util.NullLogger()` is created when None.
    """
    # Fixed references and configuration.
    self._agent = agent
    self._env = env
    self._learner = learner
    self._unroll_length = unroll_length

    # Mutable rollout state. The external calls keep their original order:
    # reset the environment, then ask the agent for its initial state.
    self._timestep = env.reset()
    self._agent_state = agent.initial_state(None)
    self._traj = []
    self._episode_return = 0.0

    # Random key derived from the integer seed.
    self._rng_key = jax.random.PRNGKey(rng_seed)

    # Fall back to a no-op logger when none is supplied.
    self._logger = util.NullLogger() if logger is None else logger