def run(agent: base.Agent,
        environment: dm_env.Environment,
        num_episodes: int,
        verbose: bool = False) -> None:
    """Runs an agent on an environment.

    Note that for bsuite environments, logging is handled internally.

    Args:
        agent: The agent to train and evaluate.
        environment: The environment to train on.
        num_episodes: Number of episodes to train for.
        verbose: Whether to also log to terminal.
    """
    if verbose:
        environment = terminal_logging.wrap_environment(
            environment, log_every=True)  # pytype: disable=wrong-arg-types

    for _ in range(num_episodes):
        # Run an episode.
        timestep = environment.reset()
        while not timestep.last():
            # Generate an action from the agent's policy.
            action = agent.select_action(timestep)

            # Step the environment.
            new_timestep = environment.step(action)

            # Tell the agent about what just happened.
            agent.update(timestep, action, new_timestep)

            # Book-keeping.
            timestep = new_timestep
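# Illustrative only: a minimal agent satisfying the interface `run` expects
# (`select_action(timestep)` and `update(timestep, action, new_timestep)`).
# The class name and the use of a dm_env DiscreteArray action spec are
# assumptions for this sketch, not part of the source above.
import numpy as np
from dm_env import specs


class RandomAgent:
    """Hypothetical agent that samples uniformly from a discrete action spec."""

    def __init__(self, action_spec: specs.DiscreteArray, seed: int = 0):
        self._action_spec = action_spec
        self._rng = np.random.RandomState(seed)

    def select_action(self, timestep):
        # Ignore the observation and pick any valid action.
        return self._rng.randint(self._action_spec.num_values)

    def update(self, timestep, action, new_timestep):
        # A learning agent would update its parameters here; no-op in this sketch.
        pass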
def run(
    agent: base.Agent,
    env: dm_env.Environment,
    num_episodes: int,
    eval_mode: bool = False,
) -> base.Agent:
    wandb.init(project="dqn")
    logging.info(
        "Starting {} agent {} on environment {}.\nThe scheduled number of episodes is {}".format(
            "evaluating" if eval_mode else "training", agent, env, num_episodes
        )
    )

    for episode in range(num_episodes):
        print(
            "Starting episode number {}/{}\t\t\t".format(episode, num_episodes - 1),
            end="\r",
        )
        wandb.log({"Episode": episode})

        # initialise environment
        timestep = env.reset()

        while not timestep.last():
            # policy
            action = agent.select_action(timestep)

            # step environment
            new_timestep = env.step(tuple(action))
            wandb.log({"Reward": new_timestep.reward})

            # update
            if not eval_mode:
                loss = agent.update(timestep, action, new_timestep)
                if loss is not None:
                    wandb.log({"Bellman MSE": float(loss)})
                wandb.log({"Iteration": agent.iteration})

            # prepare next
            timestep = new_timestep
    return agent
def run_episode(agent: Agent,
                env: Environment,
                action_repeat: int = 1,
                update: bool = False):
    start_time = time()
    episode_steps = 0
    episode_return = 0

    timestep = env.reset()
    agent.observe_first(timestep)

    while not timestep.last():
        action = agent.select_action(timestep.observation)
        for _ in range(action_repeat):
            timestep = env.step(action)
            agent.observe(action, next_timestep=timestep)
            episode_steps += 1
            episode_return += timestep.reward

            if update:
                agent.update()

            if timestep.last():
                break

    steps_per_second = episode_steps / (time() - start_time)
    result = {
        'episode_length': episode_steps,
        'episode_return': episode_return,
        'steps_per_second': steps_per_second,
    }
    return result
def transition_dataset(environment: dm_env.Environment) -> tf.data.Dataset:
    """Fake dataset of Reverb N-step transition samples.

    Args:
        environment: Used to create a fake transition by looking at the
            observation, action, discount and reward specs.

    Returns:
        tf.data.Dataset that produces the same fake N-step transition
        ReplaySample object indefinitely.
    """
    observation = environment.observation_spec().generate_value()
    action = environment.action_spec().generate_value()
    reward = environment.reward_spec().generate_value()
    discount = environment.discount_spec().generate_value()
    data = (observation, action, reward, discount, observation)

    key = np.array(0, np.uint64)
    probability = np.array(1.0, np.float64)
    table_size = np.array(1, np.int64)
    priority = np.array(1.0, np.float64)
    info = reverb.SampleInfo(key=key,
                             probability=probability,
                             table_size=table_size,
                             priority=priority)

    sample = reverb.ReplaySample(info=info, data=data)
    return tf.data.Dataset.from_tensors(sample).repeat()
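# Consumption sketch for `transition_dataset` above (eager TF2 assumed): pull
# one fake ReplaySample. `environment` stands for any dm_env.Environment whose
# specs implement `generate_value()`; it is an assumption of this sketch.
dataset = transition_dataset(environment)
sample = next(iter(dataset))  # a reverb.ReplaySample of tf tensors
observation, action, reward, discount, next_observation = sample.data
print(sample.info.key, reward)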
def run(agent: Agent,
        environment: dm_env.Environment,
        num_episodes: int,
        results_dir: str = 'res/default.pkl') -> None:
    '''
    Runs an agent on an environment.

    Args:
        agent: The agent to train and evaluate.
        environment: The environment to train on.
        num_episodes: Number of episodes to train for.
        results_dir: Path the trained network is saved to after each episode.
    '''
    for episode in range(num_episodes):
        # Run an episode.
        timestep = environment.reset()
        while not timestep.last():
            action = agent.select_action(timestep)
            new_timestep = environment.step(action)

            # Pass the (s, a, r, s') info to the agent.
            agent.update(timestep, action, new_timestep)

            # Update the timestep.
            timestep = new_timestep

        if (episode + 1) % 100 == 0:
            print("Episode %d success." % (episode + 1))
        torch.save(getattr(agent, '_network'), results_dir)
def make_environment_spec(environment: dm_env.Environment) -> EnvironmentSpec:
    """Returns an `EnvironmentSpec` describing values used by an environment."""
    return EnvironmentSpec(
        observations=environment.observation_spec(),
        actions=environment.action_spec(),
        rewards=environment.reward_spec(),
        discounts=environment.discount_spec())
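# Usage sketch for `make_environment_spec`; `env` stands for any concrete
# dm_env.Environment instance (an assumption of this sketch, not part of the
# source above).
env_spec = make_environment_spec(env)

# The spec bundles the four per-environment specs in one place, e.g. to size a
# network head from a discrete action spec (assuming it exposes `num_values`).
num_actions = env_spec.actions.num_values
print(env_spec.observations, num_actions)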
def run_loop(
    agent: Agent,
    environment: dm_env.Environment,
    max_steps_per_episode: int = 0,
    yield_before_reset: bool = False,
) -> Iterable[Tuple[dm_env.Environment, Optional[dm_env.TimeStep], Agent,
                    Optional[Action]]]:
    """Repeatedly alternates step calls on environment and agent.

    At time `t`, `t + 1` environment timesteps and `t + 1` agent steps have been
    seen in the current episode. `t` resets to `0` for the next episode.

    Args:
        agent: Agent to be run, has methods `step(timestep)` and `reset()`.
        environment: Environment to run, has methods `step(action)` and `reset()`.
        max_steps_per_episode: If positive, when time t reaches this value within
            an episode, the episode is truncated.
        yield_before_reset: Whether to additionally yield `(environment, None,
            agent, None)` before the agent and environment are reset at the start
            of each episode.

    Yields:
        Tuple `(environment, timestep_t, agent, a_t)` where
        `a_t = agent.step(timestep_t)`.
    """
    while True:  # For each episode.
        if yield_before_reset:
            yield environment, None, agent, None

        t = 0
        agent.reset()
        timestep_t = environment.reset()  # timestep_0.

        while True:  # For each step in the current episode.
            a_t = agent.step(timestep_t)
            yield environment, timestep_t, agent, a_t

            # Update t after one environment step and agent step and relabel.
            t += 1
            a_tm1 = a_t
            timestep_t = environment.step(a_tm1)

            if max_steps_per_episode > 0 and t >= max_steps_per_episode:
                assert t == max_steps_per_episode
                timestep_t = timestep_t._replace(step_type=dm_env.StepType.LAST)

            if timestep_t.last():
                unused_a_t = agent.step(timestep_t)  # Extra agent step, action ignored.
                yield environment, timestep_t, agent, None
                break
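# A minimal consumption sketch for `run_loop`, assuming `agent` and
# `environment` already exist; the generator never terminates on its own, so
# the caller imposes a step budget and breaks out explicitly. The budget and
# episode cap below are arbitrary example values.
max_steps = 10_000
for step, (unused_env, timestep_t, unused_agent, a_t) in enumerate(
        run_loop(agent, environment, max_steps_per_episode=1_000)):
    if step >= max_steps:
        break
    if timestep_t is None or a_t is None:
        continue  # pre-reset / end-of-episode yields carry no action
    # e.g. accumulate returns, write logs, or checkpoint here.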
def __init__(self,
             environment: dm_env.Environment,
             *,
             frame_rate: Optional[int] = None,
             camera_id: Optional[int] = 0,
             height: int = 240,
             width: int = 320,
             playback_speed: float = 1.,
             **kwargs):
    # Check that we have a mujoco environment (or a wrapper thereof).
    if not hasattr(environment, '_physics'):
        raise ValueError('MujocoVideoWrapper expects an environment which '
                         'exposes a _physics attribute corresponding to a MuJoCo '
                         'physics engine')

    # Compute frame rate if not set.
    if frame_rate is None:
        frame_rate = int(round(playback_speed / environment.control_timestep()))

    super().__init__(environment, frame_rate=frame_rate, **kwargs)
    self._camera_id = camera_id
    self._height = height
    self._width = width
def __init__(self,
             environment: dm_env.Environment,
             name_filter: Optional[Sequence[str]] = None):
    """Initializes a new ConcatObservationWrapper.

    Args:
        environment: Environment to wrap.
        name_filter: Sequence of observation names to keep. None keeps them all.
    """
    super().__init__(environment)
    observation_spec = environment.observation_spec()
    if name_filter is None:
        name_filter = list(observation_spec.keys())
    self._obs_names = [x for x in name_filter if x in observation_spec.keys()]

    dummy_obs = _zeros_like(observation_spec)
    dummy_obs = self._convert_observation(dummy_obs)
    self._observation_spec = dm_env.specs.BoundedArray(
        shape=dummy_obs.shape,
        dtype=dummy_obs.dtype,
        minimum=-np.inf,
        maximum=np.inf,
        name='state')
def __init__(
    self,
    env: Environment,
    alpha: float,
    lamda: float,
    n_features: int,
    logger: Logger = NullLogger(),
):
    super().__init__()
    self.alpha = alpha
    self.lamda = lamda
    self.w = jnp.zeros(n_features + 1)
    self.z = jnp.zeros(n_features + 1)
    self.logger = logger
    self.policy = EGreedy(env.action_spec, 0.1)
    # temporary value estimate for the starting state
    self._actions = jnp.arange(env.action_spec().num_values)
    self._q_old = jnp.zeros(env.action_spec().num_values)
def _make_ma_environment_spec(
        self, environment: dm_env.Environment) -> Dict[str, EnvironmentSpec]:
    """Returns an `EnvironmentSpec` describing values used by an environment
    for each agent."""
    specs = {}
    observation_specs = environment.observation_spec()
    action_specs = environment.action_spec()
    reward_specs = environment.reward_spec()
    discount_specs = environment.discount_spec()
    self.extra_specs = environment.extra_spec()
    for agent in environment.possible_agents:
        specs[agent] = EnvironmentSpec(
            observations=observation_specs[agent],
            actions=action_specs[agent],
            rewards=reward_specs[agent],
            discounts=discount_specs[agent],
        )
    return specs
def assert_env_reset(
    wrapped_env: dm_env.Environment,
    dm_env_timestep: dm_env.TimeStep,
    env_spec: EnvSpec,
) -> None:
    if env_spec.env_type == EnvType.Parallel:
        rewards_spec = wrapped_env.reward_spec()
        expected_rewards = {
            agent: convert_np_type(rewards_spec[agent].dtype, 0)
            for agent in wrapped_env.agents
        }

        discount_spec = wrapped_env.discount_spec()
        expected_discounts = {
            agent: convert_np_type(discount_spec[agent].dtype, 1)
            for agent in wrapped_env.agents
        }

        assert Helpers.compare_dicts(
            dm_env_timestep.reward, expected_rewards
        ), "Failed to reset reward."
        assert Helpers.compare_dicts(
            dm_env_timestep.discount, expected_discounts
        ), "Failed to reset discount."

    elif env_spec.env_type == EnvType.Sequential:
        for agent in wrapped_env.agents:
            rewards_spec = wrapped_env.reward_spec()
            expected_reward = convert_np_type(rewards_spec[agent].dtype, 0)

            discount_spec = wrapped_env.discount_spec()
            expected_discount = convert_np_type(discount_spec[agent].dtype, 1)

            assert dm_env_timestep.reward == expected_reward and type(
                dm_env_timestep.reward) == type(
                    expected_reward), "Failed to reset reward."
            assert dm_env_timestep.discount == expected_discount and type(
                dm_env_timestep.discount) == type(
                    expected_discount), "Failed to reset discount."
def observe(
    self, env: dm_env.Environment, timestep: dm_env.TimeStep, action: int
) -> dm_env.TimeStep:
    # iterate over the number of steps
    for _ in range(self.hparams.n_steps):
        # get new MDP state
        new_timestep = env.step(action)

        # store transition into the replay buffer
        self.memory.add(timestep, action, new_timestep, preprocess=self.preprocess)
        timestep = new_timestep
    return timestep
def __init__(
    self,
    environment: dm_env.Environment,
    additional_discount: float = 0.99,
    max_abs_reward: Optional[float] = 1.0,
    resize_shape: Optional[Tuple[int, int]] = (84, 84),
    num_action_repeats: int = 4,
    num_pooled_frames: int = 2,
    zero_discount_on_life_loss: bool = True,
    num_stacked_frames: int = 4,
    grayscaling: bool = True,
):
    rgb_spec, unused_lives_spec = environment.observation_spec()
    if rgb_spec.shape[2] != 3:
        raise ValueError(
            'This wrapper assumes interleaved pixel observations with shape '
            '(height, width, channels).')
    if int(environment.action_spec().minimum) != 0:
        raise ValueError('This wrapper assumes zero-indexed actions.')

    self._environment = environment
    self._processor = atari(
        additional_discount=additional_discount,
        max_abs_reward=max_abs_reward,
        resize_shape=resize_shape,
        num_action_repeats=num_action_repeats,
        num_pooled_frames=num_pooled_frames,
        zero_discount_on_life_loss=zero_discount_on_life_loss,
        num_stacked_frames=num_stacked_frames,
        grayscaling=grayscaling,
    )

    if grayscaling:
        self._observation_shape = resize_shape + (num_stacked_frames,)
        self._observation_spec_name = 'grayscale'
    else:
        self._observation_shape = resize_shape + (3, num_stacked_frames)
        self._observation_spec_name = 'RGB'

    self._reset_next_step = True
def transition_iterator(
    environment: dm_env.Environment
) -> Callable[[int], Iterator[types.Transition]]:
    """Fake dataset of Reverb N-step transition samples.

    Args:
        environment: Used to create a fake transition by looking at the
            observation, action, discount and reward specs.

    Returns:
        A callable that given a batch_size returns an iterator with
        demonstrations.
    """
    observation = environment.observation_spec().generate_value()
    action = environment.action_spec().generate_value()
    reward = environment.reward_spec().generate_value()
    discount = environment.discount_spec().generate_value()
    data = types.Transition(observation, action, reward, discount, observation)

    dataset = tf.data.Dataset.from_tensors(data).repeat()

    return lambda batch_size: dataset.batch(batch_size).as_numpy_iterator()
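# Usage sketch for `transition_iterator`: the returned callable is keyed by
# batch size. `environment` stands for any dm_env.Environment whose specs
# implement `generate_value()`; it is an assumption of this sketch.
make_iterator = transition_iterator(environment)
iterator = make_iterator(32)
batch = next(iterator)  # a types.Transition of numpy arrays with leading dim 32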
def actor(server: Connection, client: Connection, env: dm_env.Environment):
    def _step(env, a: int):
        timestep = env.step(a)
        if timestep.last():
            timestep = env.reset()
        return timestep

    def _step_async(env, a: int, buffer: mp.Queue):
        timestep = env.step(a)
        buffer.put(timestep)
        print(buffer.qsize())
        if timestep.last():
            timestep = env.reset()
        return

    # close copy of server connection from client process
    # see: https://stackoverflow.com/q/8594909/6655465
    server.close()

    # switch case command
    try:
        while True:
            cmd, data = client.recv()
            if cmd == "step":
                client.send(_step(env, data))
            elif cmd == "step_async":
                client.send(_step_async(env, *data))
            elif cmd == "reset":
                client.send(env.reset())
            elif cmd == "render":
                client.send(env.render(data))
            elif cmd == "close":
                client.send(env.close())
                break
            else:
                raise NotImplementedError("Command {} is not implemented".format(cmd))
    except KeyboardInterrupt:
        logging.info("SubprocVecEnv actor: got KeyboardInterrupt")
    finally:
        env.close()
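# A driver-side sketch for `actor`, assuming the usual multiprocessing Pipe
# pairing: the child keeps the `client` end and the parent keeps the `server`
# end (mirroring the `server.close()` call inside `actor`). `make_env` is a
# hypothetical environment constructor, and the environment is assumed to be
# fork-friendly or picklable.
import multiprocessing as mp

if __name__ == "__main__":
    server, client = mp.Pipe()
    env = make_env()
    process = mp.Process(target=actor, args=(server, client, env), daemon=True)
    process.start()
    client.close()  # parent only talks through the server end

    server.send(("reset", None))
    timestep = server.recv()
    server.send(("step", 0))  # an integer action
    timestep = server.recv()
    server.send(("close", None))
    server.recv()
    process.join()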
def transition_dataset(environment: dm_env.Environment) -> tf.data.Dataset:
    """Fake dataset of Reverb N-step transition samples.

    Args:
        environment: Used to create a fake transition by looking at the
            observation, action, discount and reward specs.

    Returns:
        tf.data.Dataset that produces the same fake N-step transition
        ReplaySample object indefinitely.
    """
    observation = environment.observation_spec().generate_value()
    action = environment.action_spec().generate_value()
    reward = environment.reward_spec().generate_value()
    discount = environment.discount_spec().generate_value()
    data = types.Transition(observation, action, reward, discount, observation)

    info = tree.map_structure(
        lambda tf_dtype: tf.ones([], tf_dtype.as_numpy_dtype),
        reverb.SampleInfo.tf_dtypes())
    sample = reverb.ReplaySample(info=info, data=data)

    return tf.data.Dataset.from_tensors(sample).repeat()
def run(
    self,
    env: dm_env.Environment,
    num_episodes: int,
    eval: bool = False,
) -> Loss:
    logging.info(
        "Starting {} agent {} on environment {}.\nThe scheduled number of episodes is {}".format(
            "evaluating" if eval else "training", self, env, num_episodes
        )
    )
    logging.info(
        "The hyperparameters for the current experiment are {}".format(
            self.hparams._asdict()
        )
    )

    for episode in range(num_episodes):
        print(
            "Episode {}/{}\t\t\t".format(episode, num_episodes - 1),
            end="\r",
        )

        # initialise environment
        episode_reward = 0.0
        timestep = env.reset()

        while not timestep.last():
            # apply policy
            action = self.policy(timestep)

            # observe new state
            new_timestep = self.observe(env, timestep, action)
            episode_reward += new_timestep.reward
            print(
                "Episode reward {}\t\t".format(episode_reward),
                end="\r",
            )

            # update policy
            loss = None
            if not eval:
                loss = self.update(timestep, action, new_timestep)

            # log update
            if self.logging:
                self.log(timestep, action, new_timestep, loss)

            # prepare next iteration
            timestep = new_timestep
    return loss
def __init__(
    self,
    agent: agent_lib.Agent,
    env: dm_env.Environment,
    unroll_length: int,
    learner: learner_lib.Learner,
    rng_seed: int = 42,
    logger=None,
):
    self._agent = agent
    self._env = env
    self._unroll_length = unroll_length
    self._learner = learner
    self._timestep = env.reset()
    self._agent_state = agent.initial_state(None)
    self._traj = []

    self._rng_key = jax.random.PRNGKey(rng_seed)

    if logger is None:
        logger = util.NullLogger()
    self._logger = logger

    self._episode_return = 0.
def __init__(self,
             environment: dm_env.Environment,
             *,
             max_abs_reward: Optional[float] = None,
             scale_dims: Optional[Tuple[int, int]] = (84, 84),
             action_repeats: int = 4,
             pooled_frames: int = 2,
             zero_discount_on_life_loss: bool = False,
             expose_lives_observation: bool = False,
             num_stacked_frames: int = 4,
             max_episode_len: Optional[int] = None,
             to_float: bool = False,
             grayscaling: bool = True):
    """Initializes a new AtariWrapper.

    Args:
        environment: An Atari environment.
        max_abs_reward: Maximum absolute reward value before clipping is
            applied. If set to `None` (default), no clipping is applied.
        scale_dims: Image size for the rescaling step after grayscaling, given
            as `(height, width)`. Set to `None` to disable resizing.
        action_repeats: Number of times to step wrapped environment for each
            given action.
        pooled_frames: Number of observations to pool over. Set to 1 to disable
            frame pooling.
        zero_discount_on_life_loss: If `True`, sets the discount to zero when
            the number of lives decreases in the Atari environment.
        expose_lives_observation: If `False`, the `lives` part of the
            observation is discarded, otherwise it is kept as part of an
            observation tuple. This does not affect the
            `zero_discount_on_life_loss` feature. When disabled, the observation
            consists of a single pixel array; otherwise it is a tuple
            (pixel_array, lives).
        num_stacked_frames: Number of recent (pooled) observations to stack into
            the returned observation.
        max_episode_len: Number of frames before truncating episode. By default,
            there is no maximum length.
        to_float: If `True`, rescales RGB observations to floats in [0, 1].
        grayscaling: If `True` returns a grayscale version of the observations.
            In this case, the observation is 3D (H, W, num_stacked_frames). If
            `False` the observations are RGB and have shape
            (H, W, C, num_stacked_frames).

    Raises:
        ValueError: For various invalid inputs.
    """
    if not 1 <= pooled_frames <= action_repeats:
        raise ValueError("pooled_frames ({}) must be between 1 and "
                         "action_repeats ({}) inclusive".format(
                             pooled_frames, action_repeats))

    if zero_discount_on_life_loss:
        super().__init__(_ZeroDiscountOnLifeLoss(environment))
    else:
        super().__init__(environment)

    if not max_episode_len:
        max_episode_len = np.inf

    self._frame_stacker = frame_stacking.FrameStacker(
        num_frames=num_stacked_frames)
    self._action_repeats = action_repeats
    self._pooled_frames = pooled_frames
    self._scale_dims = scale_dims
    self._max_abs_reward = max_abs_reward or np.inf
    self._to_float = to_float
    self._expose_lives_observation = expose_lives_observation

    if scale_dims:
        self._height, self._width = scale_dims
    else:
        spec = environment.observation_spec()
        self._height, self._width = spec[RGB_INDEX].shape[:2]

    self._episode_len = 0
    self._max_episode_len = max_episode_len
    self._reset_next_step = True

    self._grayscaling = grayscaling

    # Based on underlying observation spec, decide whether lives are to be
    # included in output observations.
    observation_spec = self._environment.observation_spec()
    spec_names = [spec.name for spec in observation_spec]
    if "lives" in spec_names and spec_names.index("lives") != 1:
        raise ValueError("`lives` observation needs to have index 1 in Atari.")

    self._observation_spec = self._init_observation_spec()

    self._raw_observation = None
def __init__(self, environment: dm_env.Environment, clip: bool = False):
    super().__init__(environment)
    self._action_spec = environment.action_spec()
    self._clip = clip
def __init__(self, environment: dm_env.Environment):
    self._environment = environment

    if int(environment.action_spec()[0].minimum) != 0:
        raise ValueError(
            'This wrapper assumes zero-indexed actions. Use the Atari setting '
            'zero_indexed_actions="true" to get actions in this format.')