Example #1
def test_add_data(data_format):
    # `rewards`, `observations`, `actions`, `policy_infos`, and `hidden` are defined
    # elsewhere in the original test module and are not shown in this snippet
    ro = StepSequence(rewards=rewards,
                      observations=observations,
                      actions=actions,
                      policy_infos=policy_infos,
                      hidden=hidden,
                      data_format=data_format)
    # Add a data field
    ro.add_data('return', discounted_value(ro, 0.9))
    assert hasattr(ro, 'return')

    # Query new data field from steps
    assert abs(ro[2]['return'] - -86.675) < 0.01
Example #2
def test_add_data(mock_data, data_format: str):
    rewards, states, observations, actions, hidden, policy_infos = mock_data

    ro = StepSequence(
        rewards=rewards,
        observations=observations,
        states=states,
        actions=actions,
        policy_infos=policy_infos,
        hidden=hidden,
        data_format=data_format,
    )
    # Add a data field
    ro.add_data("return", discounted_value(ro, 0.9))
    assert hasattr(ro, "return")

    # Query new data field from steps
    assert abs(ro[2]["return"] - -86.675) < 0.01
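The following minimal sketch (not part of the original tests) shows the same `add_data` call on hand-made data outside pytest; the array shapes and the `discounted_value` import path are assumptions about the pyrado API, not taken from the snippets above.

import numpy as np

from pyrado.sampling.step_sequence import StepSequence, discounted_value  # import path assumed

# Hand-made data for 4 steps; observations carry one extra entry (the final observation)
rewards = np.array([1.0, 0.0, -1.0, 2.0])
observations = np.zeros((5, 3))
actions = np.zeros((4, 2))

ro = StepSequence(rewards=rewards, observations=observations, actions=actions)

# Attach the per-step discounted return as a new data field ...
ro.add_data("return", discounted_value(ro, 0.9))

# ... and query it per step, exactly as the tests above do
print(ro[2]["return"])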
Example #3
def rollout(
    env: Env,
    policy: Union[nn.Module, Policy, Callable],
    eval: bool = False,
    max_steps: Optional[int] = None,
    reset_kwargs: Optional[dict] = None,
    render_mode: RenderMode = RenderMode(),
    render_step: int = 1,
    no_reset: bool = False,
    no_close: bool = False,
    record_dts: bool = False,
    stop_on_done: bool = True,
    seed: Optional[int] = None,
    sub_seed: Optional[int] = None,
    sub_sub_seed: Optional[int] = None,
) -> StepSequence:
    """
    Perform a rollout (i.e. sample a trajectory) in the given environment using the given policy.

    :param env: environment to use (`SimEnv` or `RealEnv`)
    :param policy: policy to determine the next action given the current observation.
                   This policy may be wrapped by an exploration strategy.
    :param eval: pass `True` if the rollout is executed during evaluation, else `False`; forwarded to the PyTorch `Module` to switch between eval and train mode
    :param max_steps: maximum number of time steps, if `None` the environment's property is used
    :param reset_kwargs: keyword arguments passed to environment's reset function
    :param render_mode: determines if the user sees an animation, console prints, or nothing
    :param render_step: rendering interval, renders every step if set to 1
    :param no_reset: do not reset the environment before running the rollout
    :param no_close: do not close (and disconnect) the environment after running the rollout
    :param record_dts: flag if the time intervals of different parts of one step should be recorded (for debugging)
    :param stop_on_done: set to `False` to ignore the environment's done flag (for debugging)
    :param seed: seed value for the random number generators, pass `None` for no seeding
    :param sub_seed: sub-seed value for the random number generators, forwarded to `pyrado.set_seed`
    :param sub_sub_seed: sub-sub-seed value for the random number generators, forwarded to `pyrado.set_seed`
    :return: paths of the observations, actions, rewards, and information about the environment as well as the policy
    """
    # Check the input
    if not isinstance(env, Env):
        raise pyrado.TypeErr(given=env, expected_type=Env)
    # Don't restrain policy type, can be any callable
    if not isinstance(eval, bool):
        raise pyrado.TypeErr(given=eval, expected_type=bool)
    # The max_steps argument is checked by the environment's setter
    if not (isinstance(reset_kwargs, dict) or reset_kwargs is None):
        raise pyrado.TypeErr(given=reset_kwargs, expected_type=dict)
    if not isinstance(render_mode, RenderMode):
        raise pyrado.TypeErr(given=render_mode, expected_type=RenderMode)

    # Initialize the paths
    obs_hist = []
    act_hist = []
    act_app_hist = []
    rew_hist = []
    state_hist = []
    env_info_hist = []
    t_hist = []
    if isinstance(policy, Policy):
        if policy.is_recurrent:
            hidden_hist = []
        # If an ExplStrat is passed use the policy property, if a Policy is passed use it directly
        if isinstance(getattr(policy, "policy", policy), PotentialBasedPolicy):
            pot_hist = []
            stim_ext_hist = []
            stim_int_hist = []
        elif isinstance(getattr(policy, "policy", policy), TwoHeadedPolicy):
            head_2_hist = []
    # The dt histories are also filled when the policy is a plain callable, so initialize them unconditionally
    if record_dts:
        dt_policy_hist = []
        dt_step_hist = []
        dt_remainder_hist = []

    # Override the number of steps to execute
    if max_steps is not None:
        env.max_steps = max_steps

    # Set all rngs' seeds (call before resetting)
    if seed is not None:
        pyrado.set_seed(seed, sub_seed, sub_sub_seed)

    # Reset the environment and pass the kwargs
    if reset_kwargs is None:
        reset_kwargs = dict()
    obs = np.zeros(env.obs_space.shape) if no_reset else env.reset(**reset_kwargs)

    # Setup rollout information
    rollout_info = dict(env_name=env.name, env_spec=env.spec)
    if isinstance(inner_env(env), SimEnv):
        rollout_info["domain_param"] = env.domain_param

    if isinstance(policy, Policy):
        # Reset the policy, i.e. the exploration strategy in case of step-based exploration.
        # If the environment is a simulation, the current domain parameters are passed to the policy. This allows
        # the policy to update its internal model, e.g. for the energy-based swing-up controllers
        if isinstance(env, SimEnv):
            policy.reset(domain_param=env.domain_param)
        else:
            policy.reset()

        # Set dropout and batch normalization layers to the right mode
        if eval:
            policy.eval()
        else:
            policy.train()

        # Check for recurrent policy, which requires initializing the hidden state
        if policy.is_recurrent:
            hidden = policy.init_hidden()

    # Initialize animation
    env.render(render_mode, render_step=1)

    # Initialize the main loop variables
    done = False
    t = 0.0  # time starts at zero
    t_hist.append(t)
    if record_dts:
        t_post_step = time.time()  # first sample of remainder is useless

    # ----------
    # Begin loop
    # ----------

    # Terminate if the environment signals done; the loop also stops once the maximum number of steps is reached
    while not (done and stop_on_done) and env.curr_step < env.max_steps:
        # Record step start time
        if record_dts or render_mode.video:
            t_start = time.time()  # dual purpose
        if record_dts:
            dt_remainder = t_start - t_post_step

        # Check observations
        if np.isnan(obs).any():
            env.render(render_mode, render_step=1)
            raise pyrado.ValueErr(
                msg=f"At least one observation value is NaN!"
                + tabulate(
                    [list(env.obs_space.labels), [*color_validity(obs, np.invert(np.isnan(obs)))]], headers="firstrow"
                )
            )

        # Get the agent's action
        obs_to = to.from_numpy(obs).type(to.get_default_dtype())  # policy operates on PyTorch tensors
        with to.no_grad():
            if isinstance(policy, Policy):
                if policy.is_recurrent:
                    if isinstance(getattr(policy, "policy", policy), TwoHeadedPolicy):
                        act_to, head_2_to, hidden_next = policy(obs_to, hidden)
                    else:
                        act_to, hidden_next = policy(obs_to, hidden)
                else:
                    if isinstance(getattr(policy, "policy", policy), TwoHeadedPolicy):
                        act_to, head_2_to = policy(obs_to)
                    else:
                        act_to = policy(obs_to)
            else:
                # If the policy is not of type Policy, it should still operate on PyTorch tensors
                act_to = policy(obs_to)

        act = act_to.detach().cpu().numpy()  # environment operates on numpy arrays

        # Check actions
        if np.isnan(act).any():
            env.render(render_mode, render_step=1)
            raise pyrado.ValueErr(
                msg=f"At least one action value is NaN!"
                + tabulate(
                    [list(env.act_space.labels), [*color_validity(act, np.invert(np.isnan(act)))]], headers="firstrow"
                )
            )

        # Record time after the action was calculated
        if record_dts:
            t_post_policy = time.time()

        # Ask the environment to perform the simulation step
        state = env.state.copy()
        obs_next, rew, done, env_info = env.step(act)

        # Get the potentially clipped action, i.e. the one that was actually done in the environment
        act_app = env.limit_act(act)

        # Record time after the step, i.e. after the send and receive completed
        if record_dts:
            t_post_step = time.time()
            dt_policy = t_post_policy - t_start
            dt_step = t_post_step - t_post_policy

        # Record data
        obs_hist.append(obs)
        act_hist.append(act)
        act_app_hist.append(act_app)
        rew_hist.append(rew)
        state_hist.append(state)
        env_info_hist.append(env_info)
        if record_dts:
            dt_policy_hist.append(dt_policy)
            dt_step_hist.append(dt_step)
            dt_remainder_hist.append(dt_remainder)
            t += dt_policy + dt_step + dt_remainder
        else:
            t += env.dt
        t_hist.append(t)
        if isinstance(policy, Policy):
            if policy.is_recurrent:
                hidden_hist.append(hidden)
                hidden = hidden_next
            # If an ExplStrat is passed use the policy property, if a Policy is passed use it directly
            if isinstance(getattr(policy, "policy", policy), PotentialBasedPolicy):
                pot_hist.append(hidden)
                stim_ext_hist.append(getattr(policy, "policy", policy).stimuli_external.detach().cpu().numpy())
                stim_int_hist.append(getattr(policy, "policy", policy).stimuli_internal.detach().cpu().numpy())
            elif isinstance(getattr(policy, "policy", policy), TwoHeadedPolicy):
                head_2_hist.append(head_2_to)

        # Store the observation for next step (if done, this is the final observation)
        obs = obs_next

        # Render if wanted (actually renders the next state)
        env.render(render_mode, render_step)
        if render_mode.video:
            do_sleep = True
            if pyrado.mujoco_loaded:
                from pyrado.environments.mujoco.base import MujocoSimEnv

                if isinstance(env, MujocoSimEnv):
                    # MuJoCo environments seem to crash on time.sleep()
                    do_sleep = False
            if do_sleep:
                # Measure time spent and sleep if needed
                t_end = time.time()
                t_sleep = env.dt + t_start - t_end
                if t_sleep > 0:
                    time.sleep(t_sleep)

    # --------
    # End loop
    # --------

    if not no_close:
        # Disconnect from the RealEnv instance (does nothing for SimEnv instances)
        env.close()

    # Add final observation to observations list
    obs_hist.append(obs)
    state_hist.append(env.state.copy())

    # Return result object
    res = StepSequence(
        observations=obs_hist,
        actions=act_hist,
        actions_applied=act_app_hist,
        rewards=rew_hist,
        states=state_hist,
        time=t_hist,
        rollout_info=rollout_info,
        env_infos=env_info_hist,
        complete=True,  # the rollout function always returns complete paths
    )

    # Add special entries to the resulting rollout
    if isinstance(policy, Policy):
        if policy.is_recurrent:
            res.add_data("hidden_states", hidden_hist)
        if isinstance(getattr(policy, "policy", policy), PotentialBasedPolicy):
            res.add_data("potentials", pot_hist)
            res.add_data("stimuli_external", stim_ext_hist)
            res.add_data("stimuli_internal", stim_int_hist)
        elif isinstance(getattr(policy, "policy", policy), TwoHeadedPolicy):
            res.add_data("head_2", head_2_hist)
    if record_dts:
        res.add_data("dts_policy", dt_policy_hist)
        res.add_data("dts_step", dt_step_hist)
        res.add_data("dts_remainder", dt_remainder_hist)

    return res
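A hedged usage sketch of the rollout function above: `make_env` and `make_policy` are hypothetical placeholders for whatever environment and policy construction an experiment uses, and the `RenderMode` import path is an assumption.

from pyrado.utils.data_types import RenderMode  # import path assumed

env = make_env()                # hypothetical placeholder: any pyrado Env (SimEnv or RealEnv)
policy = make_policy(env.spec)  # hypothetical placeholder: any Policy or plain callable

# Sample one evaluation trajectory with a fixed seed and without rendering
ro = rollout(env, policy, eval=True, seed=0, render_mode=RenderMode())

# The returned StepSequence exposes the recorded paths as arrays
print(ro.rewards.sum())         # undiscounted return of the trajectory
print(ro.observations.shape)    # one more entry than rewards (the final observation)
print(ro.rollout_info)          # contains the domain parameters for simulation environments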
Example #4
def rollout(env: Env,
            policy: Union[nn.Module, Policy],
            eval: bool = False,
            max_steps: int = None,
            reset_kwargs: dict = None,
            render_mode: RenderMode = RenderMode(),
            render_step: int = 1,
            bernoulli_reset: float = None,
            no_reset: bool = False,
            no_close: bool = False,
            record_dts: bool = False,
            stop_on_done: bool = True) -> StepSequence:
    """
    Perform a rollout (i.e. sample a trajectory) in the given environment using the given policy.

    :param env: environment to use (`SimEnv` or `RealEnv`)
    :param policy: policy to determine the next action given the current observation.
                   This policy may be wrapped by an exploration strategy.
    :param eval: flag if the rollout is executed during training (`False`) or during evaluation (`True`)
    :param max_steps: maximum number of time steps, if `None` the environment's property is used
    :param reset_kwargs: keyword arguments passed to environment's reset function
    :param render_mode: determines if the user sees an animation, console prints, or nothing
    :param render_step: rendering interval, renders every step if set to 1
    :param bernoulli_reset: probability for resetting after the current time step
    :param no_reset: do not reset the environment before running the rollout
    :param no_close: do not close (and disconnect) the environment after running the rollout
    :param record_dts: flag if the time intervals of different parts of one step should be recorded (for debugging)
    :param stop_on_done: set to `False` to ignore the environment's done flag (for debugging)
    :return: paths of the observations, actions, rewards, and information about the environment as well as the policy
    """
    # Check the input
    if not isinstance(env, Env):
        raise pyrado.TypeErr(given=env, expected_type=Env)
    # Don't restrain policy type, can be any callable
    if not isinstance(eval, bool):
        raise pyrado.TypeErr(given=eval, expected_type=bool)
    # The max_steps argument is checked by the environment's setter
    if not (isinstance(reset_kwargs, dict) or reset_kwargs is None):
        raise pyrado.TypeErr(given=reset_kwargs, expected_type=dict)
    if not isinstance(render_mode, RenderMode):
        raise pyrado.TypeErr(given=render_mode, expected_type=RenderMode)

    # Initialize the paths
    obs_hist = []
    act_hist = []
    rew_hist = []
    env_info_hist = []
    if policy.is_recurrent:
        hidden_hist = []
    # If an ExplStrat is passed use the policy property, if a Policy is passed use it directly
    if isinstance(getattr(policy, 'policy', policy), (ADNPolicy, NFPolicy)):
        pot_hist = []
        stim_ext_hist = []
        stim_int_hist = []
    elif isinstance(getattr(policy, 'policy', policy), TwoHeadedPolicy):
        head_2_hist = []
    if record_dts:
        dt_policy_hist = []
        dt_step_hist = []
        dt_remainder_hist = []

    # Override the number of steps to execute
    if max_steps is not None:
        env.max_steps = max_steps

    # Reset the environment and pass the kwargs
    if reset_kwargs is None:
        reset_kwargs = {}
    if not no_reset:
        obs = env.reset(**reset_kwargs)
    else:
        obs = np.zeros(env.obs_space.shape)

    if isinstance(policy, Policy):
        # Reset the policy / the exploration strategy
        policy.reset()

        # Set dropout and batch normalization layers to the right mode
        if eval:
            policy.eval()
        else:
            policy.train()

    # Check for recurrent policy, which requires special handling
    if policy.is_recurrent:
        # Initialize hidden state var
        hidden = policy.init_hidden()

    # Setup rollout information
    rollout_info = dict(env_spec=env.spec)
    if isinstance(inner_env(env), SimEnv):
        rollout_info['domain_param'] = env.domain_param

    # Initialize animation
    env.render(render_mode, render_step=1)

    # Initialize the main loop variables
    done = False
    if record_dts:
        t_post_step = time.time()  # first sample of remainder is useless

    # ----------
    # Begin loop
    # ----------

    # Terminate if the environment signals done; the loop also stops once the maximum number of steps is reached
    while not (done and stop_on_done) and env.curr_step < env.max_steps:
        # Record step start time
        if record_dts or render_mode.video:
            t_start = time.time()  # dual purpose
        if record_dts:
            dt_remainder = t_start - t_post_step

        # Check observations
        if np.isnan(obs).any():
            env.render(render_mode, render_step=1)
            raise pyrado.ValueErr(
                msg=f'At least one observation value is NaN!' +
                    tabulate([list(env.obs_space.labels),
                              [*color_validity(obs, np.invert(np.isnan(obs)))]], headers='firstrow')
            )

        # Get the agent's action
        obs_to = to.from_numpy(obs).type(to.get_default_dtype())  # policy operates on PyTorch tensors
        with to.no_grad():
            if policy.is_recurrent:
                if isinstance(getattr(policy, 'policy', policy), TwoHeadedPolicy):
                    act_to, head_2_to, hidden_next = policy(obs_to, hidden)
                else:
                    act_to, hidden_next = policy(obs_to, hidden)
            else:
                if isinstance(getattr(policy, 'policy', policy), TwoHeadedPolicy):
                    act_to, head_2_to = policy(obs_to)
                else:
                    act_to = policy(obs_to)

        act = act_to.detach().cpu().numpy()  # environment operates on numpy arrays

        # Check actions
        if np.isnan(act).any():
            env.render(render_mode, render_step=1)
            raise pyrado.ValueErr(
                msg=f'At least one action value is NaN!' +
                    tabulate([list(env.act_space.labels),
                              [*color_validity(act, np.invert(np.isnan(act)))]], headers='firstrow')
            )

        # Record time after the action was calculated
        if record_dts:
            t_post_policy = time.time()

        # Ask the environment to perform the simulation step
        obs_next, rew, done, env_info = env.step(act)

        # Record time after the step, i.e. after the send and receive completed
        if record_dts:
            t_post_step = time.time()
            dt_policy = t_post_policy - t_start
            dt_step = t_post_step - t_post_policy

        # Record data
        obs_hist.append(obs)
        act_hist.append(act)
        rew_hist.append(rew)
        env_info_hist.append(env_info)
        if record_dts:
            dt_policy_hist.append(dt_policy)
            dt_step_hist.append(dt_step)
            dt_remainder_hist.append(dt_remainder)
        if policy.is_recurrent:
            hidden_hist.append(hidden)
            hidden = hidden_next
        # If an ExplStrat is passed use the policy property, if a Policy is passed use it directly
        if isinstance(getattr(policy, 'policy', policy), (ADNPolicy, NFPolicy)):
            pot_hist.append(getattr(policy, 'policy', policy).potentials.detach().numpy())
            stim_ext_hist.append(getattr(policy, 'policy', policy).stimuli_external.detach().numpy())
            stim_int_hist.append(getattr(policy, 'policy', policy).stimuli_internal.detach().numpy())
        elif isinstance(getattr(policy, 'policy', policy), TwoHeadedPolicy):
            head_2_hist.append(head_2_to)

        # Store the observation for next step (if done, this is the final observation)
        obs = obs_next

        # Render if wanted (actually renders the next state)
        env.render(render_mode, render_step)

        if render_mode.video:
            do_sleep = True
            if pyrado.mujoco_available:
                from pyrado.environments.mujoco.base import MujocoSimEnv
                if isinstance(env, MujocoSimEnv):
                    # MuJoCo environments seem to crash on time.sleep()
                    do_sleep = False
            if do_sleep:
                # Measure time spent and sleep if needed
                t_end = time.time()
                t_sleep = env.dt + t_start - t_end
                if t_sleep > 0:
                    time.sleep(t_sleep)

        # Stochastic reset to make the MDP ergodic (e.g. used for REPS)
        if bernoulli_reset is not None:
            assert 0. <= bernoulli_reset <= 1.
            # Stop the rollout with probability bernoulli_reset (most common choice is 1 - gamma)
            if binomial(1, bernoulli_reset):
                # The complete=True in the returned StepSequence sets the last done element to True
                break

    # --------
    # End loop
    # --------

    if not no_close:
        # Disconnect from the RealEnv instance (does nothing for SimEnv instances)
        env.close()

    # Add final observation to observations list
    obs_hist.append(obs)

    # Return result object
    res = StepSequence(
        observations=obs_hist,
        actions=act_hist,
        rewards=rew_hist,
        rollout_info=rollout_info,
        env_infos=env_info_hist,
        complete=True  # the rollout function always returns complete paths
    )

    # Add special entries to the resulting rollout
    if policy.is_recurrent:
        res.add_data('hidden_states', hidden_hist)
    if isinstance(getattr(policy, 'policy', policy), (ADNPolicy, NFPolicy)):
        res.add_data('potentials', pot_hist)
        res.add_data('stimuli_external', stim_ext_hist)
        res.add_data('stimuli_internal', stim_int_hist)
    elif isinstance(getattr(policy, 'policy', policy), TwoHeadedPolicy):
        res.add_data('head_2', head_2_hist)
    if record_dts:
        res.add_data('dts_policy', dt_policy_hist)
        res.add_data('dts_step', dt_step_hist)
        res.add_data('dts_remainder', dt_remainder_hist)

    return res
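A short, hedged sketch of the stochastic-reset option that distinguishes this older variant: with `bernoulli_reset = 1 - gamma` the rollout terminates after each step with probability 1 - gamma, so the expected length is roughly 1 / (1 - gamma) steps (the choice mentioned in the code, e.g. for REPS-style ergodicity). `env` and `policy` are hypothetical placeholders as above.

gamma = 0.99  # discount factor of the learning algorithm (illustrative value)

# Expected rollout length of roughly 1 / (1 - gamma) = 100 steps, capped by env.max_steps
ro = rollout(env, policy, eval=False, bernoulli_reset=1 - gamma)
print(len(ro.rewards))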