Exemplo n.º 1
0
def test_rollout_based(env: SimEnv, policy: Policy):
    """
    Collect one rollout (with `record_dts=True`) and run the plotting routines that match the policy type.

    :param env: environment to roll out in
    :param policy: policy to roll out; a `LinearPolicy` gets the feature plot, a `PotentialBasedPolicy` gets the
                   potentials plot, every other policy gets all the generic plots plus the time deltas
    """
    trajectory = rollout(env, policy, record_dts=True)

    # Policy-specific plots: exactly one plot, then we are done
    if isinstance(policy, LinearPolicy):
        plot_features(trajectory, policy)
        return
    if isinstance(policy, PotentialBasedPolicy):
        plot_potentials(trajectory)
        return

    # Generic fallback: the combined plot, the individual plots, and the recorded time deltas
    plot_observations_actions_rewards(trajectory)
    plot_observations(trajectory)
    plot_actions(trajectory, env)
    plot_rewards(trajectory)
    draw_dts(
        trajectory.dts_policy,
        trajectory.dts_step,
        trajectory.dts_remainder,
        y_top_lim=5,
    )
Exemplo n.º 2
0
def after_rollout_query(
    env: Env, policy: Policy, rollout: StepSequence
) -> Tuple[bool, Optional[np.ndarray], Optional[dict]]:
    """
    Ask the user what to do after a rollout has been animated.

    The plotting and information commands re-open the query once they are done. Input that can not be processed
    (unknown hotkey, non-numeric values, wrong state shape, unsupported environment type) re-opens the query, too.
    Every path that returns hands back the same 3-tuple.

    :param env: environment used for the rollout
    :param policy: policy used for the rollout
    :param rollout: collected data from the rollout
    :return: done flag, initial state, and domain parameters
    """
    # First entry contains hotkey, second the info text
    options = [
        ["C", "continue simulation (with domain randomization)"],
        ["N", "set domain parameters to nominal values, and continue"],
        ["F", "fix the initial state"],
        ["I", "print information about environment (including randomizer), and policy"],
        ["S", "set a domain parameter explicitly"],
        ["P", "plot all observations, actions, and rewards"],
        ["PS [indices]", "plot all states, or selected ones by passing separated integers"],
        ["PO [indices]", "plot all observations, or selected ones by passing separated integers"],
        ["PA", "plot actions"],
        ["PR", "plot rewards"],
        ["PF", "plot features (for linear policy)"],
        ["PP", "plot policy parameters (not suggested for many parameters)"],
        ["PPOT", "plot potentials, stimuli, and actions (for potential-based policies)"],
        ["PDT", "plot time deltas (profiling of a real system)"],
        ["E", "exit"],
    ]

    # Ask for user input
    ans = input(tabulate(options, tablefmt="simple") + "\n").lower()

    if ans == "c" or ans == "":
        # We don't have to do anything here since the env will be reset at the beginning of the next rollout
        return False, None, None

    elif ans == "f":
        try:
            base_env = inner_env(env)
            # Only simulation environments are accepted here, everything else is rejected the same way
            if isinstance(base_env, RealEnv) or not isinstance(base_env, SimEnv):
                raise pyrado.TypeErr(given=base_env, expected_type=SimEnv)

            # Get the user input
            usr_inp = input(
                f"Enter the {env.obs_space.flat_dim}-dim initial state "
                f"(format: each dim separated by a whitespace):\n"
            )
            # float() raises a ValueError for non-numeric tokens, which is handled below
            state = np.array(list(map(float, usr_inp.split())))
            if state.shape != env.obs_space.shape:
                raise pyrado.ShapeErr(given=state, expected_match=env.obs_space)
            return False, state, {}
        except (ValueError, pyrado.TypeErr, pyrado.ShapeErr) as err:
            print_cbt(f"Could not set the initial state: {err}", "r")
            return after_rollout_query(env, policy, rollout)

    elif ans == "n":
        # Get nominal domain parameters
        if isinstance(inner_env(env), SimEnv):
            dp_nom = inner_env(env).get_nominal_domain_param()
            if typed_env(env, ActDelayWrapper) is not None:
                # There is an ActDelayWrapper in the env chain
                dp_nom["act_delay"] = 0
        else:
            dp_nom = None
        return False, None, dp_nom

    elif ans == "i":
        # Print the information and return to the query
        print(env)
        if hasattr(env, "randomizer"):
            print(env.randomizer)
        print(policy)
        return after_rollout_query(env, policy, rollout)

    elif ans == "p":
        plot_observations_actions_rewards(rollout)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "pa":
        plot_actions(rollout, env)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "po":
        plot_observations(rollout)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif "po" in ans and any(char.isdigit() for char in ans):
        # Every whitespace-separated token that is an integer is used as an index
        idcs = [int(s) for s in ans.split() if s.isdigit()]
        plot_observations(rollout, idcs_sel=idcs)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "ps":
        plot_states(rollout)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif "ps" in ans and any(char.isdigit() for char in ans):
        # Every whitespace-separated token that is an integer is used as an index
        idcs = [int(s) for s in ans.split() if s.isdigit()]
        plot_states(rollout, idcs_sel=idcs)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "pf":
        plot_features(rollout, policy)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "pp":
        draw_policy_params(policy, env.spec, annotate=False)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "pr":
        plot_rewards(rollout)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "pdt":
        draw_dts(rollout.dts_policy, rollout.dts_step, rollout.dts_remainder)
        plt.show()
        # Return the 3-tuple itself, not a 1-tuple wrapped around it
        return after_rollout_query(env, policy, rollout)

    elif ans == "ppot":
        plot_potentials(rollout)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "s":
        # Look through the wrappers, like it is done for the nominal parameters in the "n" branch
        if isinstance(inner_env(env), SimEnv):
            dp_nom = inner_env(env).get_nominal_domain_param()
            # Wrap every value in a one-element list to make it iterable for tabulate. This is done on a copy
            # such that the dict returned by the environment is not modified.
            dp_table = {k: [v] for k, v in dp_nom.items()}
            print("These are the nominal domain parameters:")
            print(tabulate(dp_table, headers="keys", tablefmt="simple"))

        # Get the user input
        strs = input("Enter one new domain parameter\n(format: key whitespace value):\n")
        try:
            # A line that does not consist of exactly two tokens, or a value that is not a number, raises a ValueError
            param = {}
            for line in strs.splitlines():
                key, value = line.split()
                param[key] = float(value)
            return False, None, param
        except ValueError:
            print_cbt(f"Could not parse {strs} into a dict.", "r")
            # Hand the result of the repeated query back to the caller instead of dropping it
            return after_rollout_query(env, policy, rollout)

    elif ans == "e":
        env.close()
        return True, None, {}  # breaks the outer while loop

    else:
        return after_rollout_query(env, policy, rollout)  # recursion
Exemplo n.º 3
0
if __name__ == '__main__':
    # Simulation settings: 5000 steps of 1/5000 s each, i.e. max_steps*dt = 1 s of simulated time
    max_steps = 5000
    dt = 1 / 5000.

    # Set up environment
    sim_kwargs = dict(
        physicsEngine='Bullet',  # Bullet or Vortex
        dt=dt,
        max_steps=max_steps,
        max_dist_force=None,
    )
    env = QQubeRcsSim(**sim_kwargs)
    print_domain_params(env.domain_param)

    # Set up policy: the time argument is ignored, i.e. constant acceleration with 1. rad/s**2
    policy = TimePolicy(env.spec, lambda t: [1.], dt)

    # Simulate, starting from a fixed initial state with the second entry set to 3 degrees (given in rad)
    init_state = np.array([0, 3 / 180 * np.pi, 0, 0.])
    ro = rollout(
        env,
        policy,
        render_mode=RenderMode(video=True),
        reset_kwargs=dict(init_state=init_state),
    )

    # Report the expected value and its difference to the third entry of the last observation
    print(f'After {max_steps*dt} s of accelerating with 1. rad/s**2, we should be at {max_steps*dt} rad/s')
    print(f'Difference: {max_steps*dt - ro.observations[-1][2]} rad/s (mind the swinging pendulum)')

    # Plot
    plot_observations_actions_rewards(ro)
Exemplo n.º 4
0
def after_rollout_query(env: Env, policy: Policy, rollout: StepSequence) -> tuple:
    """
    Ask the user what to do after a rollout has been animated.

    The plotting commands re-open the query once they are done. Input that can not be processed (unknown hotkey,
    non-numeric values, wrong state shape, unsupported environment type) re-opens the query, too. Every path that
    returns hands back the same 3-tuple.

    :param env: environment used for the rollout
    :param policy: policy used for the rollout
    :param rollout: collected data from the rollout
    :return: done flag, initial state, and domain parameters
    """
    # First entry contains hotkey, second the info text
    options = [
        ['C', 'continue simulation (with domain randomization)'],
        ['N', 'set domain parameters to nominal values and continue'],
        ['F', 'fix the initial state'],
        ['S', 'set a domain parameter explicitly'],
        ['P', 'plot all observations, actions, and rewards'],
        ['PA', 'plot actions'],
        ['PR', 'plot rewards'],
        ['PO', 'plot all observations'],
        ['PO idcs', 'plot selected observations (integers separated by whitespaces)'],
        ['PF', 'plot features (for linear policy)'],
        ['PP', 'plot policy parameters (not suggested for many parameters)'],
        ['PDT', 'plot time deltas (profiling of a real system)'],
        ['PPOT', 'plot potentials, stimuli, and actions'],
        ['E', 'exit']
    ]

    # Ask for user input
    ans = input(tabulate(options, tablefmt='simple') + '\n').lower()

    if ans == 'c' or ans == '':
        # We don't have to do anything here since the env will be reset at the beginning of the next rollout
        return False, None, None

    elif ans == 'n':
        # Get nominal domain parameters
        if isinstance(inner_env(env), SimEnv):
            dp_nom = inner_env(env).get_nominal_domain_param()
            if typed_env(env, ActDelayWrapper) is not None:
                # There is an ActDelayWrapper in the env chain
                dp_nom['act_delay'] = 0
        else:
            dp_nom = None
        return False, None, dp_nom

    elif ans == 'p':
        plot_observations_actions_rewards(rollout)
        return after_rollout_query(env, policy, rollout)

    elif ans == 'pa':
        plot_actions(rollout, env)
        return after_rollout_query(env, policy, rollout)

    elif ans == 'po':
        plot_observations(rollout)
        return after_rollout_query(env, policy, rollout)

    elif 'po' in ans and any(char.isdigit() for char in ans):
        # Every whitespace-separated token that is an integer is used as an index
        idcs = [int(s) for s in ans.split() if s.isdigit()]
        plot_observations(rollout, idcs_sel=idcs)
        return after_rollout_query(env, policy, rollout)

    elif ans == 'pf':
        plot_features(rollout, policy)
        return after_rollout_query(env, policy, rollout)

    elif ans == 'pp':
        from matplotlib import pyplot as plt
        render_policy_params(policy, env.spec, annotate=False)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == 'pr':
        plot_rewards(rollout)
        return after_rollout_query(env, policy, rollout)

    elif ans == 'pdt':
        plot_dts(rollout.dts_policy, rollout.dts_step, rollout.dts_remainder)
        # Return the 3-tuple itself, not a 1-tuple wrapped around it
        return after_rollout_query(env, policy, rollout)

    elif ans == 'ppot':
        plot_potentials(rollout)
        return after_rollout_query(env, policy, rollout)

    elif ans == 's':
        # Look through the wrappers, like it is done for the nominal parameters in the 'n' branch
        if isinstance(inner_env(env), SimEnv):
            dp_nom = inner_env(env).get_nominal_domain_param()
            # Wrap every value in a one-element list to make it iterable for tabulate. This is done on a copy
            # such that the dict returned by the environment is not modified.
            dp_table = {k: [v] for k, v in dp_nom.items()}
            print('These are the nominal domain parameters:')
            print(tabulate(dp_table, headers="keys", tablefmt='simple'))

        # Get the user input
        strs = input('Enter one new domain parameter\n(format: key whitespace value):\n')
        try:
            # A line that does not consist of exactly two tokens, or a value that is not a number, raises a ValueError
            param = {}
            for line in strs.splitlines():
                key, value = line.split()
                param[key] = float(value)
            return False, None, param
        except ValueError:
            print_cbt(f'Could not parse {strs} into a dict.', 'r')
            # Hand the result of the repeated query back to the caller instead of dropping it
            return after_rollout_query(env, policy, rollout)

    elif ans == 'f':
        try:
            base_env = inner_env(env)
            # Only simulation environments are accepted here, everything else is rejected the same way
            if isinstance(base_env, RealEnv) or not isinstance(base_env, SimEnv):
                raise pyrado.TypeErr(given=base_env, expected_type=SimEnv)

            # Get the user input (the name avoids shadowing the builtin str)
            usr_inp = input(f'Enter the {env.obs_space.flat_dim}-dim initial state '
                            f'(format: each dim separated by a whitespace):\n')
            # float() raises a ValueError for non-numeric tokens, which is handled below
            state = np.array(list(map(float, usr_inp.split())))
            if state.shape != env.obs_space.shape:
                raise pyrado.ShapeErr(given=state, expected_match=env.obs_space)
            return False, state, {}
        except (ValueError, pyrado.TypeErr, pyrado.ShapeErr) as err:
            print_cbt(f'Could not set the initial state: {err}', 'r')
            return after_rollout_query(env, policy, rollout)

    elif ans == 'e':
        env.close()
        return True, None, {}  # breaks the outer while loop

    else:
        return after_rollout_query(env, policy, rollout)  # recursion