import time
from typing import Callable, Optional, Union

import numpy as np
import torch as to
from numpy.random import binomial
from tabulate import tabulate
from torch import nn

import pyrado
# The remaining names used below (StepSequence, discounted_value, Env, SimEnv, RenderMode, Policy,
# TwoHeadedPolicy, PotentialBasedPolicy, ADNPolicy, NFPolicy, inner_env, color_validity) are part of
# pyrado; their exact import paths depend on the pyrado version, so they are not spelled out here.


def test_add_data(data_format):
    # Note: `rewards`, `observations`, `actions`, `policy_infos`, and `hidden` are test data
    # defined elsewhere in the test module
    ro = StepSequence(
        rewards=rewards,
        observations=observations,
        actions=actions,
        policy_infos=policy_infos,
        hidden=hidden,
        data_format=data_format,
    )

    # Add a data field
    ro.add_data('return', discounted_value(ro, 0.9))
    assert hasattr(ro, 'return')

    # Query new data field from steps
    assert abs(ro[2]['return'] - -86.675) < 0.01
def test_add_data(mock_data, data_format: str):
    rewards, states, observations, actions, hidden, policy_infos = mock_data
    ro = StepSequence(
        rewards=rewards,
        observations=observations,
        states=states,
        actions=actions,
        policy_infos=policy_infos,
        hidden=hidden,
        data_format=data_format,
    )

    # Add a data field
    ro.add_data("return", discounted_value(ro, 0.9))
    assert hasattr(ro, "return")

    # Query new data field from steps
    assert abs(ro[2]["return"] - -86.675) < 0.01
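
# For reference, a minimal sketch of the recursion behind a discounted-value computation
# (illustrative only; pyrado's own `discounted_value` may differ in signature and
# vectorization): the value at step t is sum_{k >= t} gamma^(k - t) * r_k, accumulated
# in a single backward pass over the rewards.
def _discounted_values_sketch(rewards, gamma):
    values = []
    running = 0.0
    for r in reversed(rewards):
        running = r + gamma * running  # v_t = r_t + gamma * v_{t+1}
        values.append(running)
    return values[::-1]  # reverse back to chronological order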
def rollout(
    env: Env,
    policy: Union[nn.Module, Policy, Callable],
    eval: bool = False,
    max_steps: Optional[int] = None,
    reset_kwargs: Optional[dict] = None,
    render_mode: RenderMode = RenderMode(),
    render_step: int = 1,
    no_reset: bool = False,
    no_close: bool = False,
    record_dts: bool = False,
    stop_on_done: bool = True,
    seed: Optional[int] = None,
    sub_seed: Optional[int] = None,
    sub_sub_seed: Optional[int] = None,
) -> StepSequence:
    """
    Perform a rollout (i.e. sample a trajectory) in the given environment using given policy.

    :param env: environment to use (`SimEnv` or `RealEnv`)
    :param policy: policy to determine the next action given the current observation.
                   This policy may be wrapped by an exploration strategy.
    :param eval: pass `False` if the rollout is executed during training, else `True`. Forwarded to PyTorch `Module`.
    :param max_steps: maximum number of time steps, if `None` the environment's property is used
    :param reset_kwargs: keyword arguments passed to environment's reset function
    :param render_mode: determines if the user sees an animation, console prints, or nothing
    :param render_step: rendering interval, renders every step if set to 1
    :param no_reset: do not reset the environment before running the rollout
    :param no_close: do not close (and disconnect) the environment after running the rollout
    :param record_dts: flag if the time intervals of different parts of one step should be recorded (for debugging)
    :param stop_on_done: set to false to ignore the environment's done flag (for debugging)
    :param seed: seed value for the random number generators, pass `None` for no seeding
    :param sub_seed: sub seed value for the random number generators, pass `None` for no seeding
    :param sub_sub_seed: sub sub seed value for the random number generators, pass `None` for no seeding
    :return: paths of the observations, actions, rewards, and information about the environment as well as the policy
    """
    # Check the input
    if not isinstance(env, Env):
        raise pyrado.TypeErr(given=env, expected_type=Env)
    # Don't restrain policy type, can be any callable
    if not isinstance(eval, bool):
        raise pyrado.TypeErr(given=eval, expected_type=bool)
    # The max_steps argument is checked by the environment's setter
    if not (isinstance(reset_kwargs, dict) or reset_kwargs is None):
        raise pyrado.TypeErr(given=reset_kwargs, expected_type=dict)
    if not isinstance(render_mode, RenderMode):
        raise pyrado.TypeErr(given=render_mode, expected_type=RenderMode)

    # Initialize the paths
    obs_hist = []
    act_hist = []
    act_app_hist = []
    rew_hist = []
    state_hist = []
    env_info_hist = []
    t_hist = []
    if isinstance(policy, Policy):
        if policy.is_recurrent:
            hidden_hist = []
        # If an ExplStrat is passed use the policy property, if a Policy is passed use it directly
        if isinstance(getattr(policy, "policy", policy), PotentialBasedPolicy):
            pot_hist = []
            stim_ext_hist = []
            stim_int_hist = []
        elif isinstance(getattr(policy, "policy", policy), TwoHeadedPolicy):
            head_2_hist = []
    if record_dts:
        dt_policy_hist = []
        dt_step_hist = []
        dt_remainder_hist = []

    # Override the number of steps to execute
    if max_steps is not None:
        env.max_steps = max_steps

    # Set all rngs' seeds (call before resetting)
    if seed is not None:
        pyrado.set_seed(seed, sub_seed, sub_sub_seed)

    # Reset the environment and pass the kwargs
    if reset_kwargs is None:
        reset_kwargs = dict()
    obs = np.zeros(env.obs_space.shape) if no_reset else env.reset(**reset_kwargs)

    # Setup rollout information
    rollout_info = dict(env_name=env.name, env_spec=env.spec)
    if isinstance(inner_env(env), SimEnv):
        rollout_info["domain_param"] = env.domain_param

    if isinstance(policy, Policy):
        # Reset the policy, i.e. the exploration strategy in case of step-based exploration.
        # In case the environment is a simulation, the current domain parameters are passed to the policy.
        # This allows the policy to update its internal model, e.g. for the energy-based swing-up controllers.
        if isinstance(env, SimEnv):
            policy.reset(domain_param=env.domain_param)
        else:
            policy.reset()

        # Set dropout and batch normalization layers to the right mode
        if eval:
            policy.eval()
        else:
            policy.train()

        # Check for recurrent policy, which requires initializing the hidden state
        if policy.is_recurrent:
            hidden = policy.init_hidden()

    # Initialize animation
    env.render(render_mode, render_step=1)

    # Initialize the main loop variables
    done = False
    t = 0.0  # time starts at zero
    t_hist.append(t)
    if record_dts:
        t_post_step = time.time()  # first sample of remainder is useless

    # ----------
    # Begin loop
    # ----------

    # Terminate if the environment signals done, it also keeps track of the time
    while not (done and stop_on_done) and env.curr_step < env.max_steps:
        # Record step start time
        if record_dts or render_mode.video:
            t_start = time.time()  # dual purpose
        if record_dts:
            dt_remainder = t_start - t_post_step

        # Check observations
        if np.isnan(obs).any():
            env.render(render_mode, render_step=1)
            raise pyrado.ValueErr(
                msg="At least one observation value is NaN!"
                + tabulate(
                    [list(env.obs_space.labels), [*color_validity(obs, np.invert(np.isnan(obs)))]],
                    headers="firstrow",
                )
            )

        # Get the agent's action
        obs_to = to.from_numpy(obs).type(to.get_default_dtype())  # policy operates on PyTorch tensors
        with to.no_grad():
            if isinstance(policy, Policy):
                if policy.is_recurrent:
                    if isinstance(getattr(policy, "policy", policy), TwoHeadedPolicy):
                        act_to, head_2_to, hidden_next = policy(obs_to, hidden)
                    else:
                        act_to, hidden_next = policy(obs_to, hidden)
                else:
                    if isinstance(getattr(policy, "policy", policy), TwoHeadedPolicy):
                        act_to, head_2_to = policy(obs_to)
                    else:
                        act_to = policy(obs_to)
            else:
                # If the policy is not of type Policy, it should still operate on PyTorch tensors
                act_to = policy(obs_to)

        act = act_to.detach().cpu().numpy()  # environment operates on numpy arrays

        # Check actions
        if np.isnan(act).any():
            env.render(render_mode, render_step=1)
            raise pyrado.ValueErr(
                msg="At least one action value is NaN!"
                + tabulate(
                    [list(env.act_space.labels), [*color_validity(act, np.invert(np.isnan(act)))]],
                    headers="firstrow",
                )
            )

        # Record time after the action was calculated
        if record_dts:
            t_post_policy = time.time()

        # Ask the environment to perform the simulation step
        state = env.state.copy()
        obs_next, rew, done, env_info = env.step(act)

        # Get the potentially clipped action, i.e. the one that was actually applied in the environment
        act_app = env.limit_act(act)

        # Record time after the step, i.e. after the send and receive are completed
        if record_dts:
            t_post_step = time.time()
            dt_policy = t_post_policy - t_start
            dt_step = t_post_step - t_post_policy

        # Record data
        obs_hist.append(obs)
        act_hist.append(act)
        act_app_hist.append(act_app)
        rew_hist.append(rew)
        state_hist.append(state)
        env_info_hist.append(env_info)
        if record_dts:
            dt_policy_hist.append(dt_policy)
            dt_step_hist.append(dt_step)
            dt_remainder_hist.append(dt_remainder)
            t += dt_policy + dt_step + dt_remainder
        else:
            t += env.dt
        t_hist.append(t)
        if isinstance(policy, Policy):
            if policy.is_recurrent:
                hidden_hist.append(hidden)
                hidden = hidden_next
            # If an ExplStrat is passed use the policy property, if a Policy is passed use it directly
            if isinstance(getattr(policy, "policy", policy), PotentialBasedPolicy):
                pot_hist.append(hidden)  # for potential-based policies, the hidden state holds the potentials
                stim_ext_hist.append(getattr(policy, "policy", policy).stimuli_external.detach().cpu().numpy())
                stim_int_hist.append(getattr(policy, "policy", policy).stimuli_internal.detach().cpu().numpy())
            elif isinstance(getattr(policy, "policy", policy), TwoHeadedPolicy):
                head_2_hist.append(head_2_to)

        # Store the observation for next step (if done, this is the final observation)
        obs = obs_next

        # Render if wanted (actually renders the next state)
        env.render(render_mode, render_step)
        if render_mode.video:
            do_sleep = True
            if pyrado.mujoco_loaded:
                from pyrado.environments.mujoco.base import MujocoSimEnv

                if isinstance(env, MujocoSimEnv):
                    # MuJoCo environments seem to crash on time.sleep()
                    do_sleep = False
            if do_sleep:
                # Measure time spent and sleep if needed
                t_end = time.time()
                t_sleep = env.dt + t_start - t_end
                if t_sleep > 0:
                    time.sleep(t_sleep)

    # --------
    # End loop
    # --------

    if not no_close:
        # Disconnect from EnvReal instance (does nothing for EnvSim instances)
        env.close()

    # Add final observation to observations list
    obs_hist.append(obs)
    state_hist.append(env.state.copy())

    # Return result object
    res = StepSequence(
        observations=obs_hist,
        actions=act_hist,
        actions_applied=act_app_hist,
        rewards=rew_hist,
        states=state_hist,
        time=t_hist,
        rollout_info=rollout_info,
        env_infos=env_info_hist,
        complete=True,  # the rollout function always returns complete paths
    )

    # Add special entries to the resulting rollout
    if isinstance(policy, Policy):
        if policy.is_recurrent:
            res.add_data("hidden_states", hidden_hist)
        if isinstance(getattr(policy, "policy", policy), PotentialBasedPolicy):
            res.add_data("potentials", pot_hist)
            res.add_data("stimuli_external", stim_ext_hist)
            res.add_data("stimuli_internal", stim_int_hist)
        elif isinstance(getattr(policy, "policy", policy), TwoHeadedPolicy):
            res.add_data("head_2", head_2_hist)
    if record_dts:
        res.add_data("dts_policy", dt_policy_hist)
        res.add_data("dts_step", dt_step_hist)
        res.add_data("dts_remainder", dt_remainder_hist)

    return res
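
# Usage sketch (illustrative, not part of the source): `make_env` and `make_policy` are
# hypothetical placeholder factories; only the `rollout` call itself mirrors the signature
# defined above.
def _example_rollout_usage():
    env = make_env()  # hypothetical: returns a pyrado `Env`, e.g. a `SimEnv` subclass
    policy = make_policy(env.spec)  # hypothetical: returns a `Policy` matching the env's spaces
    # Sample one seeded trajectory in evaluation mode without rendering
    ro = rollout(env, policy, eval=True, max_steps=500, render_mode=RenderMode(), seed=0)
    # The returned StepSequence holds one more observation than actions/rewards, since the
    # final observation is appended after the loop
    assert ro.observations.shape[0] == ro.length + 1
    return ro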
# Variant of `rollout` without seeding and state recording, but with an optional stochastic
# (Bernoulli) reset for making the MDP ergodic
def rollout(env: Env,
            policy: Union[nn.Module, Policy],
            eval: bool = False,
            max_steps: Optional[int] = None,
            reset_kwargs: Optional[dict] = None,
            render_mode: RenderMode = RenderMode(),
            render_step: int = 1,
            bernoulli_reset: Optional[float] = None,
            no_reset: bool = False,
            no_close: bool = False,
            record_dts: bool = False,
            stop_on_done: bool = True) -> StepSequence:
    """
    Perform a rollout (i.e. sample a trajectory) in the given environment using given policy.

    :param env: environment to use (`SimEnv` or `RealEnv`)
    :param policy: policy to determine the next action given the current observation.
                   This policy may be wrapped by an exploration strategy.
    :param eval: flag if the rollout is executed during training (`False`) or during evaluation (`True`)
    :param max_steps: maximum number of time steps, if `None` the environment's property is used
    :param reset_kwargs: keyword arguments passed to environment's reset function
    :param render_mode: determines if the user sees an animation, console prints, or nothing
    :param render_step: rendering interval, renders every step if set to 1
    :param bernoulli_reset: probability for resetting after the current time step
    :param no_reset: do not reset the environment before running the rollout
    :param no_close: do not close (and disconnect) the environment after running the rollout
    :param record_dts: flag if the time intervals of different parts of one step should be recorded (for debugging)
    :param stop_on_done: set to false to ignore the environment's done flag (for debugging)
    :return: paths of the observations, actions, rewards, and information about the environment as well as the policy
    """
    # Check the input
    if not isinstance(env, Env):
        raise pyrado.TypeErr(given=env, expected_type=Env)
    # Don't restrain policy type, can be any callable
    if not isinstance(eval, bool):
        raise pyrado.TypeErr(given=eval, expected_type=bool)
    # The max_steps argument is checked by the environment's setter
    if not (isinstance(reset_kwargs, dict) or reset_kwargs is None):
        raise pyrado.TypeErr(given=reset_kwargs, expected_type=dict)
    if not isinstance(render_mode, RenderMode):
        raise pyrado.TypeErr(given=render_mode, expected_type=RenderMode)

    # Initialize the paths
    obs_hist = []
    act_hist = []
    rew_hist = []
    env_info_hist = []
    if policy.is_recurrent:
        hidden_hist = []
    # If an ExplStrat is passed use the policy property, if a Policy is passed use it directly
    if isinstance(getattr(policy, 'policy', policy), (ADNPolicy, NFPolicy)):
        pot_hist = []
        stim_ext_hist = []
        stim_int_hist = []
    elif isinstance(getattr(policy, 'policy', policy), TwoHeadedPolicy):
        head_2_hist = []
    if record_dts:
        dt_policy_hist = []
        dt_step_hist = []
        dt_remainder_hist = []

    # Override the number of steps to execute
    if max_steps is not None:
        env.max_steps = max_steps

    # Reset the environment and pass the kwargs
    if reset_kwargs is None:
        reset_kwargs = {}
    if not no_reset:
        obs = env.reset(**reset_kwargs)
    else:
        obs = np.zeros(env.obs_space.shape)

    if isinstance(policy, Policy):
        # Reset the policy / the exploration strategy
        policy.reset()

        # Set dropout and batch normalization layers to the right mode
        if eval:
            policy.eval()
        else:
            policy.train()

    # Check for recurrent policy, which requires special handling
    if policy.is_recurrent:
        # Initialize hidden state var
        hidden = policy.init_hidden()

    # Setup rollout information
    rollout_info = dict(env_spec=env.spec)
    if isinstance(inner_env(env), SimEnv):
        rollout_info['domain_param'] = env.domain_param

    # Initialize animation
    env.render(render_mode, render_step=1)

    # Initialize the main loop variables
    done = False
    if record_dts:
        t_post_step = time.time()  # first sample of remainder is useless

    # ----------
    # Begin loop
    # ----------

    # Terminate if the environment signals done, it also keeps track of the time
    while not (done and stop_on_done) and env.curr_step < env.max_steps:
        # Record step start time
        if record_dts or render_mode.video:
            t_start = time.time()  # dual purpose
        if record_dts:
            dt_remainder = t_start - t_post_step

        # Check observations
        if np.isnan(obs).any():
            env.render(render_mode, render_step=1)
            raise pyrado.ValueErr(
                msg='At least one observation value is NaN!' +
                    tabulate([list(env.obs_space.labels),
                              [*color_validity(obs, np.invert(np.isnan(obs)))]],
                             headers='firstrow')
            )

        # Get the agent's action
        obs_to = to.from_numpy(obs).type(to.get_default_dtype())  # policy operates on PyTorch tensors
        with to.no_grad():
            if policy.is_recurrent:
                if isinstance(getattr(policy, 'policy', policy), TwoHeadedPolicy):
                    act_to, head_2_to, hidden_next = policy(obs_to, hidden)
                else:
                    act_to, hidden_next = policy(obs_to, hidden)
            else:
                if isinstance(getattr(policy, 'policy', policy), TwoHeadedPolicy):
                    act_to, head_2_to = policy(obs_to)
                else:
                    act_to = policy(obs_to)

        act = act_to.detach().cpu().numpy()  # environment operates on numpy arrays

        # Check actions
        if np.isnan(act).any():
            env.render(render_mode, render_step=1)
            raise pyrado.ValueErr(
                msg='At least one action value is NaN!' +
                    tabulate([list(env.act_space.labels),
                              [*color_validity(act, np.invert(np.isnan(act)))]],
                             headers='firstrow')
            )

        # Record time after the action was calculated
        if record_dts:
            t_post_policy = time.time()

        # Ask the environment to perform the simulation step
        obs_next, rew, done, env_info = env.step(act)

        # Record time after the step, i.e. after the send and receive are completed
        if record_dts:
            t_post_step = time.time()
            dt_policy = t_post_policy - t_start
            dt_step = t_post_step - t_post_policy

        # Record data
        obs_hist.append(obs)
        act_hist.append(act)
        rew_hist.append(rew)
        env_info_hist.append(env_info)
        if record_dts:
            dt_policy_hist.append(dt_policy)
            dt_step_hist.append(dt_step)
            dt_remainder_hist.append(dt_remainder)
        if policy.is_recurrent:
            hidden_hist.append(hidden)
            hidden = hidden_next
        # If an ExplStrat is passed use the policy property, if a Policy is passed use it directly
        if isinstance(getattr(policy, 'policy', policy), (ADNPolicy, NFPolicy)):
            pot_hist.append(getattr(policy, 'policy', policy).potentials.detach().numpy())
            stim_ext_hist.append(getattr(policy, 'policy', policy).stimuli_external.detach().numpy())
            stim_int_hist.append(getattr(policy, 'policy', policy).stimuli_internal.detach().numpy())
        elif isinstance(getattr(policy, 'policy', policy), TwoHeadedPolicy):
            head_2_hist.append(head_2_to)

        # Store the observation for next step (if done, this is the final observation)
        obs = obs_next

        # Render if wanted (actually renders the next state)
        env.render(render_mode, render_step)
        if render_mode.video:
            do_sleep = True
            if pyrado.mujoco_available:
                from pyrado.environments.mujoco.base import MujocoSimEnv

                if isinstance(env, MujocoSimEnv):
                    # MuJoCo environments seem to crash on time.sleep()
                    do_sleep = False
            if do_sleep:
                # Measure time spent and sleep if needed
                t_end = time.time()
                t_sleep = env.dt + t_start - t_end
                if t_sleep > 0:
                    time.sleep(t_sleep)

        # Stochastic reset to make the MDP ergodic (e.g. used for REPS)
        if bernoulli_reset is not None:
            assert 0. <= bernoulli_reset <= 1.
            # Stop the rollout with probability bernoulli_reset (most common choice is 1 - gamma)
            if binomial(1, bernoulli_reset):
                # The complete=True in the returned StepSequence sets the last done element to True
                break

    # --------
    # End loop
    # --------

    if not no_close:
        # Disconnect from EnvReal instance (does nothing for EnvSim instances)
        env.close()

    # Add final observation to observations list
    obs_hist.append(obs)

    # Return result object
    res = StepSequence(
        observations=obs_hist,
        actions=act_hist,
        rewards=rew_hist,
        rollout_info=rollout_info,
        env_infos=env_info_hist,
        complete=True  # the rollout function always returns complete paths
    )

    # Add special entries to the resulting rollout
    if policy.is_recurrent:
        res.add_data('hidden_states', hidden_hist)
    if isinstance(getattr(policy, 'policy', policy), (ADNPolicy, NFPolicy)):
        res.add_data('potentials', pot_hist)
        res.add_data('stimuli_external', stim_ext_hist)
        res.add_data('stimuli_internal', stim_int_hist)
    elif isinstance(getattr(policy, 'policy', policy), TwoHeadedPolicy):
        res.add_data('head_2', head_2_hist)
    if record_dts:
        res.add_data('dts_policy', dt_policy_hist)
        res.add_data('dts_step', dt_step_hist)
        res.add_data('dts_remainder', dt_remainder_hist)

    return res